diff --git a/local/compose.go b/local/compose.go index 232a8f4b6..e31c2450d 100644 --- a/local/compose.go +++ b/local/compose.go @@ -174,7 +174,7 @@ func (s *local) Down(ctx context.Context, projectName string) error { return err } - eg, ctx := errgroup.WithContext(ctx) + eg, _ := errgroup.WithContext(ctx) w := progress.ContextWriter(ctx) for _, c := range list { container := c @@ -625,7 +625,7 @@ func (s *local) ensureNetwork(ctx context.Context, n types.NetworkConfig) error StatusText: "Create", Done: false, }) - if _, err := s.containerService.apiClient.NetworkCreate(context.Background(), n.Name, createOpts); err != nil { + if _, err := s.containerService.apiClient.NetworkCreate(ctx, n.Name, createOpts); err != nil { return errors.Wrapf(err, "failed to create network %s", n.Name) } w.Event(progress.Event{ diff --git a/local/convergence.go b/local/convergence.go index 72689d77c..f17fe4a8e 100644 --- a/local/convergence.go +++ b/local/convergence.go @@ -57,7 +57,7 @@ func (s *local) ensureService(ctx context.Context, project *types.Project, servi scale := getScale(service) - eg, ctx := errgroup.WithContext(ctx) + eg, _ := errgroup.WithContext(ctx) if len(actual) < scale { next, err := nextContainerNumber(actual) if err != nil { @@ -115,7 +115,7 @@ func (s *local) ensureService(ctx context.Context, project *types.Project, servi } func (s *local) waitDependencies(ctx context.Context, project *types.Project, service types.ServiceConfig) error { - eg, ctx := errgroup.WithContext(ctx) + eg, _ := errgroup.WithContext(ctx) for dep, config := range service.DependsOn { switch config.Condition { case "service_healthy": diff --git a/local/dependencies.go b/local/dependencies.go index c2239b6da..92237370c 100644 --- a/local/dependencies.go +++ b/local/dependencies.go @@ -20,93 +20,224 @@ package local import ( "context" + "fmt" + "strings" + "sync" "github.com/compose-spec/compose-go/types" "golang.org/x/sync/errgroup" ) -func inDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, types.ServiceConfig) error) error { - graph := buildDependencyGraph(project.Services) +type ServiceStatus int - eg, ctx := errgroup.WithContext(ctx) - results := make(chan string) - errors := make(chan error) - scheduled := map[string]bool{} - for len(graph) > 0 { - for _, n := range graph.independents() { - service := n.service - if scheduled[service.Name] { - continue - } - eg.Go(func() error { - err := fn(ctx, service) - if err != nil { - errors <- err - return err - } - results <- service.Name - return nil - }) - scheduled[service.Name] = true - } - select { - case result := <-results: - graph.resolved(result) - case err := <-errors: - return err - } +const ( + ServiceStopped ServiceStatus = iota + ServiceStarted +) + +func inDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, types.ServiceConfig) error) error { + g := NewGraph(project.Services) + if b, err := g.HasCycles(); b { + return err } + + leaves := g.Leaves() + + eg, _ := errgroup.WithContext(ctx) + eg.Go(func() error { + return run(ctx, g, eg, leaves, fn) + }) + return eg.Wait() } -type dependencyGraph map[string]node - -type node struct { - service types.ServiceConfig - dependencies []string - dependent []string -} - -func (graph dependencyGraph) independents() []node { - var nodes []node - for _, node := range graph { - if len(node.dependencies) == 0 { - nodes = append(nodes, node) +// Note: this could be `graph.walk` or whatever +func run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex, fn func(context.Context, types.ServiceConfig) error) error { + for _, node := range nodes { + n := node + // Don't start this service yet if all of its children have + // not been started yet. + if len(graph.FilterChildren(n.Service.Name, ServiceStopped)) != 0 { + continue } + + eg.Go(func() error { + err := fn(ctx, n.Service) + if err != nil { + return err + } + + graph.UpdateStatus(n.Service.Name, ServiceStarted) + + return run(ctx, graph, eg, n.GetParents(), fn) + }) } - return nodes + + return nil } -func (graph dependencyGraph) resolved(result string) { - for _, parent := range graph[result].dependent { - node := graph[parent] - node.dependencies = remove(node.dependencies, result) - graph[parent] = node - } - delete(graph, result) +type Graph struct { + Vertices map[string]*Vertex + lock sync.RWMutex } -func buildDependencyGraph(services types.Services) dependencyGraph { - graph := dependencyGraph{} - for _, s := range services { - graph[s.Name] = node{ - service: s, - } +type Vertex struct { + Key string + Service types.ServiceConfig + Status ServiceStatus + Children map[string]*Vertex + Parents map[string]*Vertex +} + +func (v *Vertex) GetParents() []*Vertex { + var res []*Vertex + for _, p := range v.Parents { + res = append(res, p) + } + return res +} + +func NewGraph(services types.Services) *Graph { + graph := &Graph{ + lock: sync.RWMutex{}, + Vertices: map[string]*Vertex{}, + } + + for _, s := range services { + graph.AddVertex(s.Name, s) } for _, s := range services { - node := graph[s.Name] for _, name := range s.GetDependencies() { - dependency := graph[name] - node.dependencies = append(node.dependencies, name) - dependency.dependent = append(dependency.dependent, s.Name) - graph[name] = dependency + graph.AddEdge(s.Name, name) } - graph[s.Name] = node } + return graph } +// We then create a constructor function for the Vertex +func NewVertex(key string, service types.ServiceConfig) *Vertex { + return &Vertex{ + Key: key, + Service: service, + Status: ServiceStopped, + Parents: map[string]*Vertex{}, + Children: map[string]*Vertex{}, + } +} + +func (g *Graph) AddVertex(key string, service types.ServiceConfig) { + g.lock.Lock() + defer g.lock.Unlock() + + v := NewVertex(key, service) + g.Vertices[key] = v +} + +func (g *Graph) AddEdge(source string, destination string) error { + g.lock.Lock() + defer g.lock.Unlock() + + sourceVertex := g.Vertices[source] + destinationVertex := g.Vertices[destination] + + if sourceVertex == nil { + return fmt.Errorf("could not find %s", source) + } + if destinationVertex == nil { + return fmt.Errorf("could not find %s", destination) + } + + // If they are already connected + if _, ok := sourceVertex.Children[destination]; ok { + return nil + } + + sourceVertex.Children[destination] = destinationVertex + destinationVertex.Parents[source] = sourceVertex + + return nil +} + +func (g *Graph) Leaves() []*Vertex { + g.lock.Lock() + defer g.lock.Unlock() + + var res []*Vertex + for _, v := range g.Vertices { + if len(v.Children) == 0 { + res = append(res, v) + } + } + + return res +} + +func (g *Graph) UpdateStatus(key string, status ServiceStatus) { + g.lock.Lock() + defer g.lock.Unlock() + g.Vertices[key].Status = status +} + +func (g *Graph) FilterChildren(key string, status ServiceStatus) []*Vertex { + g.lock.Lock() + defer g.lock.Unlock() + + var res []*Vertex + vertex := g.Vertices[key] + + for _, child := range vertex.Children { + if child.Status == status { + res = append(res, child) + } + } + + return res +} + +func (g *Graph) HasCycles() (bool, error) { + discovered := []string{} + finished := []string{} + + for _, vertex := range g.Vertices { + path := []string{ + vertex.Key, + } + if !contains(discovered, vertex.Key) && !contains(finished, vertex.Key) { + var err error + discovered, finished, err = g.visit(vertex.Key, path, discovered, finished) + + if err != nil { + return true, err + } + } + } + + return false, nil +} + +func (g *Graph) visit(key string, path []string, discovered []string, finished []string) ([]string, []string, error) { + discovered = append(discovered, key) + + for _, v := range g.Vertices[key].Children { + path := append(path, v.Key) + if contains(discovered, v.Key) { + return nil, nil, fmt.Errorf("cycle found: %s", strings.Join(path, " -> ")) + } + + if !contains(finished, v.Key) { + if _, _, err := g.visit(v.Key, path, discovered, finished); err != nil { + return nil, nil, err + } + } + } + + discovered = remove(discovered, key) + finished = append(finished, key) + return discovered, finished, nil +} + func remove(slice []string, item string) []string { var s []string for _, i := range slice {