mirror of https://github.com/docker/compose.git
deps: fix race condition during graph traversal (#9878)
Keep track of visited nodes to prevent visiting a service multiple times. This is possible when a service depends on multiple others, as an attempt could be made to visit it from multiple parents. Signed-off-by: Milas Bowman <milas.bowman@docker.com>
This commit is contained in:
parent
f44ca01fcf
commit
616777eb4a
|
@ -37,38 +37,49 @@ const (
|
||||||
ServiceStarted
|
ServiceStarted
|
||||||
)
|
)
|
||||||
|
|
||||||
type graphTraversalConfig struct {
|
type graphTraversal struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
seen map[string]struct{}
|
||||||
|
|
||||||
extremityNodesFn func(*Graph) []*Vertex // leaves or roots
|
extremityNodesFn func(*Graph) []*Vertex // leaves or roots
|
||||||
adjacentNodesFn func(*Vertex) []*Vertex // getParents or getChildren
|
adjacentNodesFn func(*Vertex) []*Vertex // getParents or getChildren
|
||||||
filterAdjacentByStatusFn func(*Graph, string, ServiceStatus) []*Vertex // filterChildren or filterParents
|
filterAdjacentByStatusFn func(*Graph, string, ServiceStatus) []*Vertex // filterChildren or filterParents
|
||||||
targetServiceStatus ServiceStatus
|
targetServiceStatus ServiceStatus
|
||||||
adjacentServiceStatusToSkip ServiceStatus
|
adjacentServiceStatusToSkip ServiceStatus
|
||||||
|
|
||||||
|
visitorFn func(context.Context, string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
func upDirectionTraversal(visitorFn func(context.Context, string) error) *graphTraversal {
|
||||||
upDirectionTraversalConfig = graphTraversalConfig{
|
return &graphTraversal{
|
||||||
extremityNodesFn: leaves,
|
extremityNodesFn: leaves,
|
||||||
adjacentNodesFn: getParents,
|
adjacentNodesFn: getParents,
|
||||||
filterAdjacentByStatusFn: filterChildren,
|
filterAdjacentByStatusFn: filterChildren,
|
||||||
adjacentServiceStatusToSkip: ServiceStopped,
|
adjacentServiceStatusToSkip: ServiceStopped,
|
||||||
targetServiceStatus: ServiceStarted,
|
targetServiceStatus: ServiceStarted,
|
||||||
|
visitorFn: visitorFn,
|
||||||
}
|
}
|
||||||
downDirectionTraversalConfig = graphTraversalConfig{
|
}
|
||||||
|
|
||||||
|
func downDirectionTraversal(visitorFn func(context.Context, string) error) *graphTraversal {
|
||||||
|
return &graphTraversal{
|
||||||
extremityNodesFn: roots,
|
extremityNodesFn: roots,
|
||||||
adjacentNodesFn: getChildren,
|
adjacentNodesFn: getChildren,
|
||||||
filterAdjacentByStatusFn: filterParents,
|
filterAdjacentByStatusFn: filterParents,
|
||||||
adjacentServiceStatusToSkip: ServiceStarted,
|
adjacentServiceStatusToSkip: ServiceStarted,
|
||||||
targetServiceStatus: ServiceStopped,
|
targetServiceStatus: ServiceStopped,
|
||||||
|
visitorFn: visitorFn,
|
||||||
}
|
}
|
||||||
)
|
}
|
||||||
|
|
||||||
// InDependencyOrder applies the function to the services of the project taking in account the dependency order
|
// InDependencyOrder applies the function to the services of the project taking in account the dependency order
|
||||||
func InDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error, options ...func(*graphTraversalConfig)) error {
|
func InDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error, options ...func(*graphTraversal)) error {
|
||||||
graph, err := NewGraph(project.Services, ServiceStopped)
|
graph, err := NewGraph(project.Services, ServiceStopped)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return visit(ctx, graph, upDirectionTraversalConfig, fn)
|
t := upDirectionTraversal(fn)
|
||||||
|
return t.visit(ctx, graph)
|
||||||
}
|
}
|
||||||
|
|
||||||
// InReverseDependencyOrder applies the function to the services of the project in reverse order of dependencies
|
// InReverseDependencyOrder applies the function to the services of the project in reverse order of dependencies
|
||||||
|
@ -77,43 +88,59 @@ func InReverseDependencyOrder(ctx context.Context, project *types.Project, fn fu
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return visit(ctx, graph, downDirectionTraversalConfig, fn)
|
t := downDirectionTraversal(fn)
|
||||||
|
return t.visit(ctx, graph)
|
||||||
}
|
}
|
||||||
|
|
||||||
func visit(ctx context.Context, g *Graph, traversalConfig graphTraversalConfig, fn func(context.Context, string) error) error {
|
func (t *graphTraversal) visit(ctx context.Context, g *Graph) error {
|
||||||
nodes := traversalConfig.extremityNodesFn(g)
|
nodes := t.extremityNodesFn(g)
|
||||||
|
|
||||||
eg, _ := errgroup.WithContext(ctx)
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
eg.Go(func() error {
|
t.run(ctx, g, eg, nodes)
|
||||||
return run(ctx, g, eg, nodes, traversalConfig, fn)
|
|
||||||
})
|
|
||||||
|
|
||||||
return eg.Wait()
|
return eg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: this could be `graph.walk` or whatever
|
// Note: this could be `graph.walk` or whatever
|
||||||
func run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex, traversalConfig graphTraversalConfig, fn func(context.Context, string) error) error {
|
func (t *graphTraversal) run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex) {
|
||||||
for _, node := range nodes {
|
for _, node := range nodes {
|
||||||
// Don't start this service yet if all of its children have
|
// Don't start this service yet if all of its children have
|
||||||
// not been started yet.
|
// not been started yet.
|
||||||
if len(traversalConfig.filterAdjacentByStatusFn(graph, node.Key, traversalConfig.adjacentServiceStatusToSkip)) != 0 {
|
if len(t.filterAdjacentByStatusFn(graph, node.Key, t.adjacentServiceStatusToSkip)) != 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
node := node
|
node := node
|
||||||
|
if !t.consume(node.Key) {
|
||||||
|
// another worker already visited this node
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
eg.Go(func() error {
|
eg.Go(func() error {
|
||||||
err := fn(ctx, node.Service)
|
err := t.visitorFn(ctx, node.Service)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
graph.UpdateStatus(node.Key, traversalConfig.targetServiceStatus)
|
graph.UpdateStatus(node.Key, t.targetServiceStatus)
|
||||||
|
|
||||||
return run(ctx, graph, eg, traversalConfig.adjacentNodesFn(node), traversalConfig, fn)
|
t.run(ctx, graph, eg, t.adjacentNodesFn(node))
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
func (t *graphTraversal) consume(nodeKey string) bool {
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
if t.seen == nil {
|
||||||
|
t.seen = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
if _, ok := t.seen[nodeKey]; ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
t.seen[nodeKey] = struct{}{}
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Graph represents project as service dependencies
|
// Graph represents project as service dependencies
|
||||||
|
|
|
@ -22,6 +22,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/compose-spec/compose-go/types"
|
"github.com/compose-spec/compose-go/types"
|
||||||
|
testify "github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"gotest.tools/assert"
|
"gotest.tools/assert"
|
||||||
)
|
)
|
||||||
|
@ -46,6 +47,51 @@ var project = types.Project{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTraversalWithMultipleParents(t *testing.T) {
|
||||||
|
dependent := types.ServiceConfig{
|
||||||
|
Name: "dependent",
|
||||||
|
DependsOn: make(types.DependsOnConfig),
|
||||||
|
}
|
||||||
|
|
||||||
|
project := types.Project{
|
||||||
|
Services: []types.ServiceConfig{dependent},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i <= 100; i++ {
|
||||||
|
name := fmt.Sprintf("svc_%d", i)
|
||||||
|
dependent.DependsOn[name] = types.ServiceDependency{}
|
||||||
|
|
||||||
|
svc := types.ServiceConfig{Name: name}
|
||||||
|
project.Services = append(project.Services, svc)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
svc := make(chan string, 10)
|
||||||
|
seen := make(map[string]int)
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
for service := range svc {
|
||||||
|
seen[service]++
|
||||||
|
}
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := InDependencyOrder(ctx, &project, func(ctx context.Context, service string) error {
|
||||||
|
svc <- service
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "Error during iteration")
|
||||||
|
close(svc)
|
||||||
|
<-done
|
||||||
|
|
||||||
|
testify.Len(t, seen, 101)
|
||||||
|
for svc, count := range seen {
|
||||||
|
assert.Equal(t, 1, count, "Service: %s", svc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestInDependencyUpCommandOrder(t *testing.T) {
|
func TestInDependencyUpCommandOrder(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
t.Cleanup(cancel)
|
t.Cleanup(cancel)
|
||||||
|
|
Loading…
Reference in New Issue