deps: fix race condition during graph traversal (#9878)

Keep track of visited nodes to prevent visiting a service multiple
times. This is possible when a service depends on multiple others,
as an attempt could be made to visit it from multiple parents.

Signed-off-by: Milas Bowman <milas.bowman@docker.com>
This commit is contained in:
Milas Bowman 2022-09-27 06:01:13 -07:00 committed by GitHub
parent f44ca01fcf
commit 616777eb4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 93 additions and 20 deletions

View File

@ -37,38 +37,49 @@ const (
ServiceStarted ServiceStarted
) )
type graphTraversalConfig struct { type graphTraversal struct {
mu sync.Mutex
seen map[string]struct{}
extremityNodesFn func(*Graph) []*Vertex // leaves or roots extremityNodesFn func(*Graph) []*Vertex // leaves or roots
adjacentNodesFn func(*Vertex) []*Vertex // getParents or getChildren adjacentNodesFn func(*Vertex) []*Vertex // getParents or getChildren
filterAdjacentByStatusFn func(*Graph, string, ServiceStatus) []*Vertex // filterChildren or filterParents filterAdjacentByStatusFn func(*Graph, string, ServiceStatus) []*Vertex // filterChildren or filterParents
targetServiceStatus ServiceStatus targetServiceStatus ServiceStatus
adjacentServiceStatusToSkip ServiceStatus adjacentServiceStatusToSkip ServiceStatus
visitorFn func(context.Context, string) error
} }
var ( func upDirectionTraversal(visitorFn func(context.Context, string) error) *graphTraversal {
upDirectionTraversalConfig = graphTraversalConfig{ return &graphTraversal{
extremityNodesFn: leaves, extremityNodesFn: leaves,
adjacentNodesFn: getParents, adjacentNodesFn: getParents,
filterAdjacentByStatusFn: filterChildren, filterAdjacentByStatusFn: filterChildren,
adjacentServiceStatusToSkip: ServiceStopped, adjacentServiceStatusToSkip: ServiceStopped,
targetServiceStatus: ServiceStarted, targetServiceStatus: ServiceStarted,
visitorFn: visitorFn,
} }
downDirectionTraversalConfig = graphTraversalConfig{ }
func downDirectionTraversal(visitorFn func(context.Context, string) error) *graphTraversal {
return &graphTraversal{
extremityNodesFn: roots, extremityNodesFn: roots,
adjacentNodesFn: getChildren, adjacentNodesFn: getChildren,
filterAdjacentByStatusFn: filterParents, filterAdjacentByStatusFn: filterParents,
adjacentServiceStatusToSkip: ServiceStarted, adjacentServiceStatusToSkip: ServiceStarted,
targetServiceStatus: ServiceStopped, targetServiceStatus: ServiceStopped,
visitorFn: visitorFn,
} }
) }
// InDependencyOrder applies the function to the services of the project taking in account the dependency order // InDependencyOrder applies the function to the services of the project taking in account the dependency order
func InDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error, options ...func(*graphTraversalConfig)) error { func InDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error, options ...func(*graphTraversal)) error {
graph, err := NewGraph(project.Services, ServiceStopped) graph, err := NewGraph(project.Services, ServiceStopped)
if err != nil { if err != nil {
return err return err
} }
return visit(ctx, graph, upDirectionTraversalConfig, fn) t := upDirectionTraversal(fn)
return t.visit(ctx, graph)
} }
// InReverseDependencyOrder applies the function to the services of the project in reverse order of dependencies // InReverseDependencyOrder applies the function to the services of the project in reverse order of dependencies
@ -77,43 +88,59 @@ func InReverseDependencyOrder(ctx context.Context, project *types.Project, fn fu
if err != nil { if err != nil {
return err return err
} }
return visit(ctx, graph, downDirectionTraversalConfig, fn) t := downDirectionTraversal(fn)
return t.visit(ctx, graph)
} }
func visit(ctx context.Context, g *Graph, traversalConfig graphTraversalConfig, fn func(context.Context, string) error) error { func (t *graphTraversal) visit(ctx context.Context, g *Graph) error {
nodes := traversalConfig.extremityNodesFn(g) nodes := t.extremityNodesFn(g)
eg, _ := errgroup.WithContext(ctx) eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error { t.run(ctx, g, eg, nodes)
return run(ctx, g, eg, nodes, traversalConfig, fn)
})
return eg.Wait() return eg.Wait()
} }
// Note: this could be `graph.walk` or whatever // Note: this could be `graph.walk` or whatever
func run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex, traversalConfig graphTraversalConfig, fn func(context.Context, string) error) error { func (t *graphTraversal) run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex) {
for _, node := range nodes { for _, node := range nodes {
// Don't start this service yet if all of its children have // Don't start this service yet if all of its children have
// not been started yet. // not been started yet.
if len(traversalConfig.filterAdjacentByStatusFn(graph, node.Key, traversalConfig.adjacentServiceStatusToSkip)) != 0 { if len(t.filterAdjacentByStatusFn(graph, node.Key, t.adjacentServiceStatusToSkip)) != 0 {
continue continue
} }
node := node node := node
if !t.consume(node.Key) {
// another worker already visited this node
continue
}
eg.Go(func() error { eg.Go(func() error {
err := fn(ctx, node.Service) err := t.visitorFn(ctx, node.Service)
if err != nil { if err != nil {
return err return err
} }
graph.UpdateStatus(node.Key, traversalConfig.targetServiceStatus) graph.UpdateStatus(node.Key, t.targetServiceStatus)
return run(ctx, graph, eg, traversalConfig.adjacentNodesFn(node), traversalConfig, fn) t.run(ctx, graph, eg, t.adjacentNodesFn(node))
return nil
}) })
} }
}
return nil func (t *graphTraversal) consume(nodeKey string) bool {
t.mu.Lock()
defer t.mu.Unlock()
if t.seen == nil {
t.seen = make(map[string]struct{})
}
if _, ok := t.seen[nodeKey]; ok {
return false
}
t.seen[nodeKey] = struct{}{}
return true
} }
// Graph represents project as service dependencies // Graph represents project as service dependencies

View File

@ -22,6 +22,7 @@ import (
"testing" "testing"
"github.com/compose-spec/compose-go/types" "github.com/compose-spec/compose-go/types"
testify "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gotest.tools/assert" "gotest.tools/assert"
) )
@ -46,6 +47,51 @@ var project = types.Project{
}, },
} }
func TestTraversalWithMultipleParents(t *testing.T) {
dependent := types.ServiceConfig{
Name: "dependent",
DependsOn: make(types.DependsOnConfig),
}
project := types.Project{
Services: []types.ServiceConfig{dependent},
}
for i := 1; i <= 100; i++ {
name := fmt.Sprintf("svc_%d", i)
dependent.DependsOn[name] = types.ServiceDependency{}
svc := types.ServiceConfig{Name: name}
project.Services = append(project.Services, svc)
}
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
svc := make(chan string, 10)
seen := make(map[string]int)
done := make(chan struct{})
go func() {
for service := range svc {
seen[service]++
}
done <- struct{}{}
}()
err := InDependencyOrder(ctx, &project, func(ctx context.Context, service string) error {
svc <- service
return nil
})
require.NoError(t, err, "Error during iteration")
close(svc)
<-done
testify.Len(t, seen, 101)
for svc, count := range seen {
assert.Equal(t, 1, count, "Service: %s", svc)
}
}
func TestInDependencyUpCommandOrder(t *testing.T) { func TestInDependencyUpCommandOrder(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel) t.Cleanup(cancel)