diff --git a/pkg/compose/dependencies.go b/pkg/compose/dependencies.go index 75e08914f..18af8fba6 100644 --- a/pkg/compose/dependencies.go +++ b/pkg/compose/dependencies.go @@ -37,38 +37,49 @@ const ( ServiceStarted ) -type graphTraversalConfig struct { +type graphTraversal struct { + mu sync.Mutex + seen map[string]struct{} + extremityNodesFn func(*Graph) []*Vertex // leaves or roots adjacentNodesFn func(*Vertex) []*Vertex // getParents or getChildren filterAdjacentByStatusFn func(*Graph, string, ServiceStatus) []*Vertex // filterChildren or filterParents targetServiceStatus ServiceStatus adjacentServiceStatusToSkip ServiceStatus + + visitorFn func(context.Context, string) error } -var ( - upDirectionTraversalConfig = graphTraversalConfig{ +func upDirectionTraversal(visitorFn func(context.Context, string) error) *graphTraversal { + return &graphTraversal{ extremityNodesFn: leaves, adjacentNodesFn: getParents, filterAdjacentByStatusFn: filterChildren, adjacentServiceStatusToSkip: ServiceStopped, targetServiceStatus: ServiceStarted, + visitorFn: visitorFn, } - downDirectionTraversalConfig = graphTraversalConfig{ +} + +func downDirectionTraversal(visitorFn func(context.Context, string) error) *graphTraversal { + return &graphTraversal{ extremityNodesFn: roots, adjacentNodesFn: getChildren, filterAdjacentByStatusFn: filterParents, adjacentServiceStatusToSkip: ServiceStarted, targetServiceStatus: ServiceStopped, + visitorFn: visitorFn, } -) +} // InDependencyOrder applies the function to the services of the project taking in account the dependency order -func InDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error, options ...func(*graphTraversalConfig)) error { +func InDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error, options ...func(*graphTraversal)) error { graph, err := NewGraph(project.Services, ServiceStopped) if err != nil { return err } - return visit(ctx, graph, upDirectionTraversalConfig, fn) + t := upDirectionTraversal(fn) + return t.visit(ctx, graph) } // InReverseDependencyOrder applies the function to the services of the project in reverse order of dependencies @@ -77,43 +88,59 @@ func InReverseDependencyOrder(ctx context.Context, project *types.Project, fn fu if err != nil { return err } - return visit(ctx, graph, downDirectionTraversalConfig, fn) + t := downDirectionTraversal(fn) + return t.visit(ctx, graph) } -func visit(ctx context.Context, g *Graph, traversalConfig graphTraversalConfig, fn func(context.Context, string) error) error { - nodes := traversalConfig.extremityNodesFn(g) +func (t *graphTraversal) visit(ctx context.Context, g *Graph) error { + nodes := t.extremityNodesFn(g) - eg, _ := errgroup.WithContext(ctx) - eg.Go(func() error { - return run(ctx, g, eg, nodes, traversalConfig, fn) - }) + eg, ctx := errgroup.WithContext(ctx) + t.run(ctx, g, eg, nodes) return eg.Wait() } // Note: this could be `graph.walk` or whatever -func run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex, traversalConfig graphTraversalConfig, fn func(context.Context, string) error) error { +func (t *graphTraversal) run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex) { for _, node := range nodes { // Don't start this service yet if all of its children have // not been started yet. - if len(traversalConfig.filterAdjacentByStatusFn(graph, node.Key, traversalConfig.adjacentServiceStatusToSkip)) != 0 { + if len(t.filterAdjacentByStatusFn(graph, node.Key, t.adjacentServiceStatusToSkip)) != 0 { continue } node := node + if !t.consume(node.Key) { + // another worker already visited this node + continue + } + eg.Go(func() error { - err := fn(ctx, node.Service) + err := t.visitorFn(ctx, node.Service) if err != nil { return err } - graph.UpdateStatus(node.Key, traversalConfig.targetServiceStatus) + graph.UpdateStatus(node.Key, t.targetServiceStatus) - return run(ctx, graph, eg, traversalConfig.adjacentNodesFn(node), traversalConfig, fn) + t.run(ctx, graph, eg, t.adjacentNodesFn(node)) + return nil }) } +} - return nil +func (t *graphTraversal) consume(nodeKey string) bool { + t.mu.Lock() + defer t.mu.Unlock() + if t.seen == nil { + t.seen = make(map[string]struct{}) + } + if _, ok := t.seen[nodeKey]; ok { + return false + } + t.seen[nodeKey] = struct{}{} + return true } // Graph represents project as service dependencies diff --git a/pkg/compose/dependencies_test.go b/pkg/compose/dependencies_test.go index baaa98ce8..56b2a0433 100644 --- a/pkg/compose/dependencies_test.go +++ b/pkg/compose/dependencies_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/compose-spec/compose-go/types" + testify "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gotest.tools/assert" ) @@ -46,6 +47,51 @@ var project = types.Project{ }, } +func TestTraversalWithMultipleParents(t *testing.T) { + dependent := types.ServiceConfig{ + Name: "dependent", + DependsOn: make(types.DependsOnConfig), + } + + project := types.Project{ + Services: []types.ServiceConfig{dependent}, + } + + for i := 1; i <= 100; i++ { + name := fmt.Sprintf("svc_%d", i) + dependent.DependsOn[name] = types.ServiceDependency{} + + svc := types.ServiceConfig{Name: name} + project.Services = append(project.Services, svc) + } + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + svc := make(chan string, 10) + seen := make(map[string]int) + done := make(chan struct{}) + go func() { + for service := range svc { + seen[service]++ + } + done <- struct{}{} + }() + + err := InDependencyOrder(ctx, &project, func(ctx context.Context, service string) error { + svc <- service + return nil + }) + require.NoError(t, err, "Error during iteration") + close(svc) + <-done + + testify.Len(t, seen, 101) + for svc, count := range seen { + assert.Equal(t, 1, count, "Service: %s", svc) + } +} + func TestInDependencyUpCommandOrder(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel)