From c901edd65d94c50a404aeb79a3a46bdf9801a019 Mon Sep 17 00:00:00 2001 From: Nicolas De Loof Date: Thu, 24 Jun 2021 16:35:38 +0200 Subject: [PATCH] introduce `convergence` to hold per-service Containers and prevent race conditions Signed-off-by: Nicolas De Loof --- local/e2e/compose/ipc_test.go | 4 +- local/e2e/compose/networks_test.go | 2 +- pkg/compose/convergence.go | 363 ++++++++++++++++------------- pkg/compose/create.go | 44 +--- pkg/compose/dependencies.go | 20 +- pkg/compose/dependencies_test.go | 8 +- pkg/compose/down.go | 10 +- pkg/compose/filters.go | 4 - pkg/compose/restart.go | 25 +- pkg/compose/run.go | 20 +- pkg/compose/start.go | 6 +- pkg/compose/status.go | 111 --------- pkg/compose/stop.go | 4 +- 13 files changed, 254 insertions(+), 367 deletions(-) delete mode 100644 pkg/compose/status.go diff --git a/local/e2e/compose/ipc_test.go b/local/e2e/compose/ipc_test.go index 5ef53d466..63a11c827 100644 --- a/local/e2e/compose/ipc_test.go +++ b/local/e2e/compose/ipc_test.go @@ -59,7 +59,7 @@ func TestIPC(t *testing.T) { t.Run("down", func(t *testing.T) { _ = c.RunDockerCmd("compose", "--project-name", projectName, "down") }) - t.Run("stop ipc mode container", func(t *testing.T) { - _ = c.RunDockerCmd("stop", "ipc_mode_container") + t.Run("remove ipc mode container", func(t *testing.T) { + _ = c.RunDockerCmd("rm", "-f", "ipc_mode_container") }) } diff --git a/local/e2e/compose/networks_test.go b/local/e2e/compose/networks_test.go index 31fa9b81d..dc50d28e0 100644 --- a/local/e2e/compose/networks_test.go +++ b/local/e2e/compose/networks_test.go @@ -86,7 +86,7 @@ func TestNetworkAliassesAndLinks(t *testing.T) { }) t.Run("curl links", func(t *testing.T) { - res := c.RunDockerCmd("compose", "-f", "./fixtures/network-alias/compose.yaml", "--project-name", projectName, "exec", "-T", "container1", "curl", "container") + res := c.RunDockerCmd("compose", "-f", "./fixtures/network-alias/compose.yaml", "--project-name", projectName, "exec", "-T", "container1", "curl", "http://container/") assert.Assert(t, strings.Contains(res.Stdout(), "Welcome to nginx!"), res.Stdout()) }) diff --git a/pkg/compose/convergence.go b/pkg/compose/convergence.go index bae322d41..a64c78d94 100644 --- a/pkg/compose/convergence.go +++ b/pkg/compose/convergence.go @@ -21,6 +21,7 @@ import ( "fmt" "strconv" "strings" + "sync" "time" "github.com/compose-spec/compose-go/types" @@ -46,76 +47,147 @@ const ( "Remove the custom name to scale the service.\n" ) -func (s *composeService) ensureScale(ctx context.Context, project *types.Project, service types.ServiceConfig, timeout *time.Duration) (*errgroup.Group, []moby.Container, error) { - cState, err := GetContextContainerState(ctx) - if err != nil { - return nil, nil, err - } - observedState := cState.GetContainers() - actual := observedState.filter(isService(service.Name)).filter(isNotOneOff) - scale, err := getScale(service) - if err != nil { - return nil, nil, err - } - eg, _ := errgroup.WithContext(ctx) - if len(actual) < scale { - next, err := nextContainerNumber(actual) - if err != nil { - return nil, actual, err - } - missing := scale - len(actual) - for i := 0; i < missing; i++ { - number := next + i - name := getContainerName(project.Name, service, number) - eg.Go(func() error { - return s.createContainer(ctx, project, service, name, number, false, true) - }) - } - } +// convergence manages service's container lifecycle. +// Based on initially observed state, it reconciles the existing container with desired state, which might include +// re-creating container, adding or removing replicas, or starting stopped containers. +// Cross services dependencies are managed by creating services in expected order and updating `service:xx` reference +// when a service has converged, so dependent ones can be managed with resolved containers references. +type convergence struct { + service *composeService + observedState map[string]Containers +} - if len(actual) > scale { - for i := scale; i < len(actual); i++ { - container := actual[i] +func newConvergence(services []string, state Containers, s *composeService) *convergence { + observedState := map[string]Containers{} + for _, s := range services { + observedState[s] = Containers{} + } + for _, c := range state.filter(isNotOneOff) { + service := c.Labels[api.ServiceLabel] + observedState[service] = append(observedState[service], c) + } + return &convergence{ + service: s, + observedState: observedState, + } +} + +func (c *convergence) apply(ctx context.Context, project *types.Project, options api.CreateOptions) error { + return InDependencyOrder(ctx, project, func(ctx context.Context, name string) error { + service, err := project.GetService(name) + if err != nil { + return err + } + + strategy := options.RecreateDependencies + if utils.StringContains(options.Services, name) { + strategy = options.Recreate + } + err = c.ensureService(ctx, project, service, strategy, options.Inherit, options.Timeout) + if err != nil { + return err + } + + c.updateProject(project, name) + return nil + }) +} + +var mu sync.Mutex + +// updateProject updates project after service converged, so dependent services relying on `service:xx` can refer to actual containers. +func (c *convergence) updateProject(project *types.Project, service string) { + containers := c.observedState[service] + container := containers[0] + + // operation is protected by a Mutex so that we can safely update project.Services while running concurrent convergence on services + mu.Lock() + defer mu.Unlock() + + for i, s := range project.Services { + if d := getDependentServiceFromMode(s.NetworkMode); d == service { + s.NetworkMode = types.NetworkModeContainerPrefix + container.ID + } + if d := getDependentServiceFromMode(s.Ipc); d == service { + s.Ipc = types.NetworkModeContainerPrefix + container.ID + } + if d := getDependentServiceFromMode(s.Pid); d == service { + s.Pid = types.NetworkModeContainerPrefix + container.ID + } + var links []string + for _, serviceLink := range s.Links { + parts := strings.Split(serviceLink, ":") + serviceName := serviceLink + serviceAlias := "" + if len(parts) == 2 { + serviceName = parts[0] + serviceAlias = parts[1] + } + if serviceName != service { + links = append(links, serviceLink) + continue + } + for _, container := range containers { + name := getCanonicalContainerName(container) + if serviceAlias != "" { + links = append(links, + fmt.Sprintf("%s:%s", name, serviceAlias)) + } + links = append(links, + fmt.Sprintf("%s:%s", name, name), + fmt.Sprintf("%s:%s", name, getContainerNameWithoutProject(container))) + } + s.Links = links + } + project.Services[i] = s + } +} + +func (c *convergence) ensureService(ctx context.Context, project *types.Project, service types.ServiceConfig, recreate string, inherit bool, timeout *time.Duration) error { + expected, err := getScale(service) + if err != nil { + return err + } + containers := c.observedState[service.Name] + actual := len(containers) + updated := make(Containers, expected) + + eg, _ := errgroup.WithContext(ctx) + + for i, container := range containers { + if i > expected { + // Scale Down eg.Go(func() error { - err := s.apiClient.ContainerStop(ctx, container.ID, timeout) + err := c.service.apiClient.ContainerStop(ctx, container.ID, timeout) if err != nil { return err } - return s.apiClient.ContainerRemove(ctx, container.ID, moby.ContainerRemoveOptions{}) - }) - } - actual = actual[:scale] - } - return eg, actual, nil -} - -func (s *composeService) ensureService(ctx context.Context, project *types.Project, service types.ServiceConfig, recreate string, inherit bool, timeout *time.Duration) error { - eg, actual, err := s.ensureScale(ctx, project, service, timeout) - if err != nil { - return err - } - - if recreate == api.RecreateNever { - return nil - } - - expected, err := ServiceHash(service) - if err != nil { - return err - } - - for _, container := range actual { - container := container - name := getContainerProgressName(container) - - diverged := container.Labels[api.ConfigHashLabel] != expected - if diverged || recreate == api.RecreateForce || service.Extensions[extLifecycle] == forceRecreate { - eg.Go(func() error { - return s.recreateContainer(ctx, project, service, container, inherit, timeout) + return c.service.apiClient.ContainerRemove(ctx, container.ID, moby.ContainerRemoveOptions{}) }) continue } + if recreate == api.RecreateNever { + continue + } + // Re-create diverged containers + configHash, err := ServiceHash(service) + if err != nil { + return err + } + name := getContainerProgressName(container) + diverged := container.Labels[api.ConfigHashLabel] != configHash + if diverged || recreate == api.RecreateForce || service.Extensions[extLifecycle] == forceRecreate { + i := i + eg.Go(func() error { + recreated, err := c.service.recreateContainer(ctx, project, service, container, inherit, timeout) + updated[i] = recreated + return err + }) + continue + } + + // Enforce non-diverged containers are running w := progress.ContextWriter(ctx) switch container.State { case ContainerRunning: @@ -126,11 +198,31 @@ func (s *composeService) ensureService(ctx context.Context, project *types.Proje w.Event(progress.CreatedEvent(name)) default: eg.Go(func() error { - return s.startContainer(ctx, container) + return c.service.startContainer(ctx, container) }) } + updated[i] = container } - return eg.Wait() + + next, err := nextContainerNumber(containers) + if err != nil { + return err + } + for i := 0; i < expected-actual; i++ { + // Scale UP + number := next + i + name := getContainerName(project.Name, service, number) + eg.Go(func() error { + container, err := c.service.createContainer(ctx, project, service, name, number, false, true) + updated[actual+i-1] = container + return err + }) + continue + } + + err = eg.Wait() + c.observedState[service.Name] = updated + return err } func getContainerName(projectName string, service types.ServiceConfig, number int) string { @@ -220,51 +312,54 @@ func getScale(config types.ServiceConfig) (int, error) { return scale, err } -func (s *composeService) createContainer(ctx context.Context, project *types.Project, service types.ServiceConfig, name string, number int, autoRemove bool, useNetworkAliases bool) error { +func (s *composeService) createContainer(ctx context.Context, project *types.Project, service types.ServiceConfig, + name string, number int, autoRemove bool, useNetworkAliases bool) (container moby.Container, err error) { w := progress.ContextWriter(ctx) eventName := "Container " + name w.Event(progress.CreatingEvent(eventName)) - err := s.createMobyContainer(ctx, project, service, name, number, nil, autoRemove, useNetworkAliases) + container, err = s.createMobyContainer(ctx, project, service, name, number, nil, autoRemove, useNetworkAliases) if err != nil { - return err + return } w.Event(progress.CreatedEvent(eventName)) - return nil + return } -func (s *composeService) recreateContainer(ctx context.Context, project *types.Project, service types.ServiceConfig, container moby.Container, inherit bool, timeout *time.Duration) error { +func (s *composeService) recreateContainer(ctx context.Context, project *types.Project, service types.ServiceConfig, + replaced moby.Container, inherit bool, timeout *time.Duration) (moby.Container, error) { + var created moby.Container w := progress.ContextWriter(ctx) - w.Event(progress.NewEvent(getContainerProgressName(container), progress.Working, "Recreate")) - err := s.apiClient.ContainerStop(ctx, container.ID, timeout) + w.Event(progress.NewEvent(getContainerProgressName(replaced), progress.Working, "Recreate")) + err := s.apiClient.ContainerStop(ctx, replaced.ID, timeout) if err != nil { - return err + return created, err } - name := getCanonicalContainerName(container) - tmpName := fmt.Sprintf("%s_%s", container.ID[:12], name) - err = s.apiClient.ContainerRename(ctx, container.ID, tmpName) + name := getCanonicalContainerName(replaced) + tmpName := fmt.Sprintf("%s_%s", replaced.ID[:12], name) + err = s.apiClient.ContainerRename(ctx, replaced.ID, tmpName) if err != nil { - return err + return created, err } - number, err := strconv.Atoi(container.Labels[api.ContainerNumberLabel]) + number, err := strconv.Atoi(replaced.Labels[api.ContainerNumberLabel]) if err != nil { - return err + return created, err } var inherited *moby.Container if inherit { - inherited = &container + inherited = &replaced } - err = s.createMobyContainer(ctx, project, service, name, number, inherited, false, true) + created, err = s.createMobyContainer(ctx, project, service, name, number, inherited, false, true) if err != nil { - return err + return created, err } - err = s.apiClient.ContainerRemove(ctx, container.ID, moby.ContainerRemoveOptions{}) + err = s.apiClient.ContainerRemove(ctx, replaced.ID, moby.ContainerRemoveOptions{}) if err != nil { - return err + return created, err } - w.Event(progress.NewEvent(getContainerProgressName(container), progress.Done, "Recreated")) + w.Event(progress.NewEvent(getContainerProgressName(replaced), progress.Done, "Recreated")) setDependentLifecycle(project, service.Name, forceRecreate) - return nil + return created, err } // setDependentLifecycle define the Lifecycle strategy for all services to depend on specified service @@ -291,35 +386,31 @@ func (s *composeService) startContainer(ctx context.Context, container moby.Cont return nil } -func (s *composeService) createMobyContainer(ctx context.Context, project *types.Project, service types.ServiceConfig, name string, number int, - inherit *moby.Container, - autoRemove bool, - useNetworkAliases bool) error { - cState, err := GetContextContainerState(ctx) - if err != nil { - return err - } +func (s *composeService) createMobyContainer(ctx context.Context, project *types.Project, service types.ServiceConfig, + name string, number int, inherit *moby.Container, autoRemove bool, useNetworkAliases bool) (moby.Container, error) { + var created moby.Container containerConfig, hostConfig, networkingConfig, err := s.getCreateOptions(ctx, project, service, number, inherit, autoRemove) if err != nil { - return err + return created, err } var plat *specs.Platform if service.Platform != "" { - p, err := platforms.Parse(service.Platform) + var p specs.Platform + p, err = platforms.Parse(service.Platform) if err != nil { - return err + return created, err } plat = &p } - created, err := s.apiClient.ContainerCreate(ctx, containerConfig, hostConfig, networkingConfig, plat, name) + response, err := s.apiClient.ContainerCreate(ctx, containerConfig, hostConfig, networkingConfig, plat, name) if err != nil { - return err + return created, err } - inspectedContainer, err := s.apiClient.ContainerInspect(ctx, created.ID) + inspectedContainer, err := s.apiClient.ContainerInspect(ctx, response.ID) if err != nil { - return err + return created, err } - createdContainer := moby.Container{ + created = moby.Container{ ID: inspectedContainer.ID, Labels: inspectedContainer.Config.Labels, Names: []string{inspectedContainer.Name}, @@ -327,11 +418,7 @@ func (s *composeService) createMobyContainer(ctx context.Context, project *types Networks: inspectedContainer.NetworkSettings.Networks, }, } - cState.Add(createdContainer) - links, err := s.getLinks(ctx, service) - if err != nil { - return err - } + links := append(service.Links, service.ExternalLinks...) for _, netName := range service.NetworksByPriority() { netwrk := project.Networks[netName] cfg := service.Networks[netName] @@ -342,21 +429,21 @@ func (s *composeService) createMobyContainer(ctx context.Context, project *types aliases = append(aliases, cfg.Aliases...) } } - if val, ok := createdContainer.NetworkSettings.Networks[netwrk.Name]; ok { - if shortIDAliasExists(createdContainer.ID, val.Aliases...) { + if val, ok := created.NetworkSettings.Networks[netwrk.Name]; ok { + if shortIDAliasExists(created.ID, val.Aliases...) { continue } - err := s.apiClient.NetworkDisconnect(ctx, netwrk.Name, createdContainer.ID, false) + err = s.apiClient.NetworkDisconnect(ctx, netwrk.Name, created.ID, false) if err != nil { - return err + return created, err } } err = s.connectContainerToNetwork(ctx, created.ID, netwrk.Name, cfg, links, aliases...) if err != nil { - return err + return created, err } } - return nil + return created, err } func shortIDAliasExists(containerID string, aliases ...string) bool { @@ -395,37 +482,6 @@ func (s *composeService) connectContainerToNetwork(ctx context.Context, id strin return nil } -func (s *composeService) getLinks(ctx context.Context, service types.ServiceConfig) ([]string, error) { - cState, err := GetContextContainerState(ctx) - if err != nil { - return nil, err - } - links := []string{} - for _, serviceLink := range service.Links { - s := strings.Split(serviceLink, ":") - serviceName := serviceLink - serviceAlias := "" - if len(s) == 2 { - serviceName = s[0] - serviceAlias = s[1] - } - containers := cState.GetContainers() - depServiceContainers := containers.filter(isService(serviceName)) - for _, container := range depServiceContainers { - name := getCanonicalContainerName(container) - if serviceAlias != "" { - links = append(links, - fmt.Sprintf("%s:%s", name, serviceAlias)) - } - links = append(links, - fmt.Sprintf("%s:%s", name, name), - fmt.Sprintf("%s:%s", name, getContainerNameWithoutProject(container))) - } - } - links = append(links, service.ExternalLinks...) - return links, nil -} - func (s *composeService) isServiceHealthy(ctx context.Context, project *types.Project, service string) (bool, error) { containers, err := s.getContainers(ctx, project.Name, oneOffExclude, false, service) if err != nil { @@ -503,26 +559,3 @@ func (s *composeService) startService(ctx context.Context, project *types.Projec } return eg.Wait() } - -func (s *composeService) restartService(ctx context.Context, serviceName string, timeout *time.Duration) error { - containerState, err := GetContextContainerState(ctx) - if err != nil { - return err - } - containers := containerState.GetContainers().filter(isService(serviceName)) - w := progress.ContextWriter(ctx) - eg, ctx := errgroup.WithContext(ctx) - for _, c := range containers { - container := c - eg.Go(func() error { - eventName := getContainerProgressName(container) - w.Event(progress.RestartingEvent(eventName)) - err := s.apiClient.ContainerRestart(ctx, container.ID, timeout) - if err == nil { - w.Event(progress.StartedEvent(eventName)) - } - return err - }) - } - return eg.Wait() -} diff --git a/pkg/compose/create.go b/pkg/compose/create.go index fb71a500e..ce6a01976 100644 --- a/pkg/compose/create.go +++ b/pkg/compose/create.go @@ -59,8 +59,6 @@ func (s *composeService) create(ctx context.Context, project *types.Project, opt if err != nil { return err } - containerState := NewContainersState(observedState) - ctx = context.WithValue(ctx, ContainersKey{}, containerState) err = s.ensureImagesExists(ctx, project, observedState, options.QuietPull) if err != nil { @@ -105,12 +103,7 @@ func (s *composeService) create(ctx context.Context, project *types.Project, opt prepareServicesDependsOn(project) - return InDependencyOrder(ctx, project, func(c context.Context, service types.ServiceConfig) error { - if utils.StringContains(options.Services, service.Name) { - return s.ensureService(c, project, service, options.Recreate, options.Inherit, options.Timeout) - } - return s.ensureService(c, project, service, options.RecreateDependencies, options.Inherit, options.Timeout) - }) + return newConvergence(options.Services, observedState, s).apply(ctx, project, options) } func prepareVolumes(p *types.Project) error { @@ -275,12 +268,8 @@ func (s *composeService) getCreateOptions(ctx context.Context, p *types.Project, resources := getDeployResources(service) - networkMode, err := getMode(ctx, service.Name, service.NetworkMode) - if err != nil { - return nil, nil, nil, err - } - if networkMode == "" { - networkMode = getDefaultNetworkMode(p, service) + if service.NetworkMode == "" { + service.NetworkMode = getDefaultNetworkMode(p, service) } var networkConfig *network.NetworkingConfig @@ -314,11 +303,6 @@ func (s *composeService) getCreateOptions(ctx context.Context, p *types.Project, break //nolint:staticcheck } - ipcmode, err := getMode(ctx, service.Name, service.Ipc) - if err != nil { - return nil, nil, nil, err - } - tmpfs := map[string]string{} for _, t := range service.Tmpfs { if arr := strings.SplitN(t, ":", 2); len(arr) > 1 { @@ -342,9 +326,9 @@ func (s *composeService) getCreateOptions(ctx context.Context, p *types.Project, Mounts: mounts, CapAdd: strslice.StrSlice(service.CapAdd), CapDrop: strslice.StrSlice(service.CapDrop), - NetworkMode: container.NetworkMode(networkMode), + NetworkMode: container.NetworkMode(service.NetworkMode), Init: service.Init, - IpcMode: container.IpcMode(ipcmode), + IpcMode: container.IpcMode(service.Ipc), ReadonlyRootfs: service.ReadOnly, RestartPolicy: getRestartPolicy(service), ShmSize: int64(service.ShmSize), @@ -913,24 +897,6 @@ func getAliases(s types.ServiceConfig, c *types.ServiceNetworkConfig) []string { return aliases } -func getMode(ctx context.Context, serviceName string, mode string) (string, error) { - cState, err := GetContextContainerState(ctx) - if err != nil { - return "", nil - } - observedState := cState.GetContainers() - depService := getDependentServiceFromMode(mode) - if depService != "" { - depServiceContainers := observedState.filter(isService(depService)) - if len(depServiceContainers) > 0 { - return types.NetworkModeContainerPrefix + depServiceContainers[0].ID, nil - } - return "", fmt.Errorf(`no containers started for %q in service %q -> %v`, - mode, serviceName, observedState) - } - return mode, nil -} - func getNetworksForService(s types.ServiceConfig) map[string]*types.ServiceNetworkConfig { if len(s.Networks) > 0 { return s.Networks diff --git a/pkg/compose/dependencies.go b/pkg/compose/dependencies.go index d278ff3e8..42cb8f649 100644 --- a/pkg/compose/dependencies.go +++ b/pkg/compose/dependencies.go @@ -63,16 +63,16 @@ var ( ) // InDependencyOrder applies the function to the services of the project taking in account the dependency order -func InDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, types.ServiceConfig) error) error { +func InDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error) error { return visit(ctx, project, upDirectionTraversalConfig, fn, ServiceStopped) } // InReverseDependencyOrder applies the function to the services of the project in reverse order of dependencies -func InReverseDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, types.ServiceConfig) error) error { +func InReverseDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error) error { return visit(ctx, project, downDirectionTraversalConfig, fn, ServiceStarted) } -func visit(ctx context.Context, project *types.Project, traversalConfig graphTraversalConfig, fn func(context.Context, types.ServiceConfig) error, initialStatus ServiceStatus) error { +func visit(ctx context.Context, project *types.Project, traversalConfig graphTraversalConfig, fn func(context.Context, string) error, initialStatus ServiceStatus) error { g := NewGraph(project.Services, initialStatus) if b, err := g.HasCycles(); b { return err @@ -89,12 +89,12 @@ func visit(ctx context.Context, project *types.Project, traversalConfig graphTra } // Note: this could be `graph.walk` or whatever -func run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex, traversalConfig graphTraversalConfig, fn func(context.Context, types.ServiceConfig) error) error { +func run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex, traversalConfig graphTraversalConfig, fn func(context.Context, string) error) error { for _, node := range nodes { n := node // Don't start this service yet if all of its children have // not been started yet. - if len(traversalConfig.filterAdjacentByStatusFn(graph, n.Service.Name, traversalConfig.adjacentServiceStatusToSkip)) != 0 { + if len(traversalConfig.filterAdjacentByStatusFn(graph, n.Service, traversalConfig.adjacentServiceStatusToSkip)) != 0 { continue } @@ -104,7 +104,7 @@ func run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex, return err } - graph.UpdateStatus(n.Service.Name, traversalConfig.targetServiceStatus) + graph.UpdateStatus(n.Service, traversalConfig.targetServiceStatus) return run(ctx, graph, eg, traversalConfig.adjacentNodesFn(n), traversalConfig, fn) }) @@ -122,7 +122,7 @@ type Graph struct { // Vertex represents a service in the dependencies structure type Vertex struct { Key string - Service types.ServiceConfig + Service string Status ServiceStatus Children map[string]*Vertex Parents map[string]*Vertex @@ -162,7 +162,7 @@ func NewGraph(services types.Services, initialStatus ServiceStatus) *Graph { } for _, s := range services { - graph.AddVertex(s.Name, s, initialStatus) + graph.AddVertex(s.Name, s.Name, initialStatus) } for _, s := range services { @@ -175,7 +175,7 @@ func NewGraph(services types.Services, initialStatus ServiceStatus) *Graph { } // NewVertex is the constructor function for the Vertex -func NewVertex(key string, service types.ServiceConfig, initialStatus ServiceStatus) *Vertex { +func NewVertex(key string, service string, initialStatus ServiceStatus) *Vertex { return &Vertex{ Key: key, Service: service, @@ -186,7 +186,7 @@ func NewVertex(key string, service types.ServiceConfig, initialStatus ServiceSta } // AddVertex adds a vertex to the Graph -func (g *Graph) AddVertex(key string, service types.ServiceConfig, initialStatus ServiceStatus) { +func (g *Graph) AddVertex(key string, service string, initialStatus ServiceStatus) { g.lock.Lock() defer g.lock.Unlock() diff --git a/pkg/compose/dependencies_test.go b/pkg/compose/dependencies_test.go index 841274b76..5d5871a54 100644 --- a/pkg/compose/dependencies_test.go +++ b/pkg/compose/dependencies_test.go @@ -47,8 +47,8 @@ var project = types.Project{ func TestInDependencyUpCommandOrder(t *testing.T) { order := make(chan string) //nolint:errcheck, unparam - go InDependencyOrder(context.TODO(), &project, func(ctx context.Context, config types.ServiceConfig) error { - order <- config.Name + go InDependencyOrder(context.TODO(), &project, func(ctx context.Context, config string) error { + order <- config return nil }) assert.Equal(t, <-order, "test3") @@ -59,8 +59,8 @@ func TestInDependencyUpCommandOrder(t *testing.T) { func TestInDependencyReverseDownCommandOrder(t *testing.T) { order := make(chan string) //nolint:errcheck, unparam - go InReverseDependencyOrder(context.TODO(), &project, func(ctx context.Context, config types.ServiceConfig) error { - order <- config.Name + go InReverseDependencyOrder(context.TODO(), &project, func(ctx context.Context, config string) error { + order <- config return nil }) assert.Equal(t, <-order, "test1") diff --git a/pkg/compose/down.go b/pkg/compose/down.go index e56b77a6e..c125913d8 100644 --- a/pkg/compose/down.go +++ b/pkg/compose/down.go @@ -50,7 +50,6 @@ func (s *composeService) down(ctx context.Context, projectName string, options a if err != nil { return err } - ctx = context.WithValue(ctx, ContainersKey{}, NewContainersState(containers)) if options.Project == nil { project, err := s.projectFromContainerLabels(containers, projectName) @@ -64,8 +63,8 @@ func (s *composeService) down(ctx context.Context, projectName string, options a resourceToRemove = true } - err = InReverseDependencyOrder(ctx, options.Project, func(c context.Context, service types.ServiceConfig) error { - serviceContainers := containers.filter(isService(service.Name)) + err = InReverseDependencyOrder(ctx, options.Project, func(c context.Context, service string) error { + serviceContainers := containers.filter(isService(service)) err := s.removeContainers(ctx, w, serviceContainers, options.Timeout, options.Volumes) return err }) @@ -236,11 +235,6 @@ func (s *composeService) removeContainers(ctx context.Context, w progress.Writer w.Event(progress.ErrorMessageEvent(eventName, "Error while Removing")) return err } - contextContainerState, err := GetContextContainerState(ctx) - if err != nil { - return err - } - contextContainerState.Remove(toDelete.ID) w.Event(progress.RemovedEvent(eventName)) return nil }) diff --git a/pkg/compose/filters.go b/pkg/compose/filters.go index 178daa10a..317353cfc 100644 --- a/pkg/compose/filters.go +++ b/pkg/compose/filters.go @@ -31,10 +31,6 @@ func serviceFilter(serviceName string) filters.KeyValuePair { return filters.Arg("label", fmt.Sprintf("%s=%s", api.ServiceLabel, serviceName)) } -func slugFilter(slug string) filters.KeyValuePair { - return filters.Arg("label", fmt.Sprintf("%s=%s", api.SlugLabel, slug)) -} - func oneOffFilter(b bool) filters.KeyValuePair { v := "False" if b { diff --git a/pkg/compose/restart.go b/pkg/compose/restart.go index 241b78705..05000e6e9 100644 --- a/pkg/compose/restart.go +++ b/pkg/compose/restart.go @@ -21,6 +21,7 @@ import ( "github.com/compose-spec/compose-go/types" "github.com/docker/compose-cli/pkg/api" + "golang.org/x/sync/errgroup" "github.com/docker/compose-cli/pkg/progress" "github.com/docker/compose-cli/pkg/utils" @@ -33,7 +34,7 @@ func (s *composeService) Restart(ctx context.Context, project *types.Project, op } func (s *composeService) restart(ctx context.Context, project *types.Project, options api.RestartOptions) error { - ctx, err := s.getUpdatedContainersStateContext(ctx, project.Name) + observedState, err := s.getContainers(ctx, project.Name, oneOffInclude, true) if err != nil { return err } @@ -42,11 +43,25 @@ func (s *composeService) restart(ctx context.Context, project *types.Project, op options.Services = project.ServiceNames() } - err = InDependencyOrder(ctx, project, func(c context.Context, service types.ServiceConfig) error { - if utils.StringContains(options.Services, service.Name) { - return s.restartService(ctx, service.Name, options.Timeout) + w := progress.ContextWriter(ctx) + err = InDependencyOrder(ctx, project, func(c context.Context, service string) error { + if !utils.StringContains(options.Services, service) { + return nil } - return nil + eg, ctx := errgroup.WithContext(ctx) + for _, c := range observedState.filter(isService(service)) { + container := c + eg.Go(func() error { + eventName := getContainerProgressName(container) + w.Event(progress.RestartingEvent(eventName)) + err := s.apiClient.ContainerRestart(ctx, container.ID, options.Timeout) + if err == nil { + w.Event(progress.StartedEvent(eventName)) + } + return err + }) + } + return eg.Wait() }) if err != nil { return err diff --git a/pkg/compose/run.go b/pkg/compose/run.go index 947e38b57..be78e55c8 100644 --- a/pkg/compose/run.go +++ b/pkg/compose/run.go @@ -25,7 +25,6 @@ import ( "github.com/compose-spec/compose-go/types" moby "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" - "github.com/docker/docker/api/types/filters" "github.com/docker/docker/pkg/stringid" ) @@ -34,8 +33,6 @@ func (s *composeService) RunOneOffContainer(ctx context.Context, project *types. if err != nil { return 0, err } - containerState := NewContainersState(observedState) - ctx = context.WithValue(ctx, ContainersKey{}, containerState) service, err := project.GetService(opts.Service) if err != nil { @@ -63,10 +60,11 @@ func (s *composeService) RunOneOffContainer(ctx context.Context, project *types. if err := s.waitDependencies(ctx, project, service); err != nil { return 0, err } - if err := s.createContainer(ctx, project, service, service.ContainerName, 1, opts.AutoRemove, opts.UseNetworkAliases); err != nil { + created, err := s.createContainer(ctx, project, service, service.ContainerName, 1, opts.AutoRemove, opts.UseNetworkAliases) + if err != nil { return 0, err } - containerID := service.ContainerName + containerID := created.ID if opts.Detach { err := s.apiClient.ContainerStart(ctx, containerID, moby.ContainerStartOptions{}) @@ -77,21 +75,13 @@ func (s *composeService) RunOneOffContainer(ctx context.Context, project *types. return 0, nil } - containers, err := s.apiClient.ContainerList(ctx, moby.ContainerListOptions{ - Filters: filters.NewArgs(slugFilter(slug)), - All: true, - }) - if err != nil { - return 0, err - } - oneoffContainer := containers[0] - restore, err := s.attachContainerStreams(ctx, oneoffContainer.ID, service.Tty, opts.Reader, opts.Writer) + restore, err := s.attachContainerStreams(ctx, containerID, service.Tty, opts.Reader, opts.Writer) if err != nil { return 0, err } defer restore() - statusC, errC := s.apiClient.ContainerWait(context.Background(), oneoffContainer.ID, container.WaitConditionNextExit) + statusC, errC := s.apiClient.ContainerWait(context.Background(), containerID, container.WaitConditionNextExit) err = s.apiClient.ContainerStart(ctx, containerID, moby.ContainerStartOptions{}) if err != nil { diff --git a/pkg/compose/start.go b/pkg/compose/start.go index a882a1e76..335230de6 100644 --- a/pkg/compose/start.go +++ b/pkg/compose/start.go @@ -53,7 +53,11 @@ func (s *composeService) start(ctx context.Context, project *types.Project, opti }) } - err := InDependencyOrder(ctx, project, func(c context.Context, service types.ServiceConfig) error { + err := InDependencyOrder(ctx, project, func(c context.Context, name string) error { + service, err := project.GetService(name) + if err != nil { + return err + } return s.startService(ctx, project, service) }) if err != nil { diff --git a/pkg/compose/status.go b/pkg/compose/status.go deleted file mode 100644 index 32af151c2..000000000 --- a/pkg/compose/status.go +++ /dev/null @@ -1,111 +0,0 @@ -/* - Copyright 2020 Docker Compose CLI authors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package compose - -import ( - "context" - - "github.com/docker/docker/api/types" - "github.com/pkg/errors" -) - -// ContainersKey is the context key to access context value os a ContainersStatus -type ContainersKey struct{} - -// ContainersState state management interface -type ContainersState interface { - Get(string) *types.Container - GetContainers() Containers - Add(c types.Container) - AddAll(cs Containers) - Remove(string) types.Container -} - -// NewContainersState creates a new container state manager -func NewContainersState(cs Containers) ContainersState { - s := containersState{ - observedContainers: &cs, - } - return &s -} - -// ContainersStatus works as a collection container for the observed containers -type containersState struct { - observedContainers *Containers -} - -func (s *containersState) AddAll(cs Containers) { - for _, c := range cs { - lValue := append(*s.observedContainers, c) - s.observedContainers = &lValue - } -} - -func (s *containersState) Add(c types.Container) { - if s.Get(c.ID) == nil { - lValue := append(*s.observedContainers, c) - s.observedContainers = &lValue - } -} - -func (s *containersState) Remove(id string) types.Container { - var c types.Container - var newObserved Containers - for _, o := range *s.observedContainers { - if o.ID == id { - c = o - continue - } - newObserved = append(newObserved, o) - } - s.observedContainers = &newObserved - return c -} - -func (s *containersState) Get(id string) *types.Container { - for _, o := range *s.observedContainers { - if id == o.ID { - return &o - } - } - return nil -} - -func (s *containersState) GetContainers() Containers { - if s.observedContainers != nil && *s.observedContainers != nil { - return *s.observedContainers - } - return make(Containers, 0) -} - -// GetContextContainerState gets the container state manager -func GetContextContainerState(ctx context.Context) (ContainersState, error) { - cState, ok := ctx.Value(ContainersKey{}).(*containersState) - if !ok { - return nil, errors.New("containers' containersState not available in context") - } - return cState, nil -} - -func (s composeService) getUpdatedContainersStateContext(ctx context.Context, projectName string) (context.Context, error) { - observedState, err := s.getContainers(ctx, projectName, oneOffInclude, true) - if err != nil { - return nil, err - } - containerState := NewContainersState(observedState) - return context.WithValue(ctx, ContainersKey{}, containerState), nil -} diff --git a/pkg/compose/stop.go b/pkg/compose/stop.go index 056a1f297..f292305ad 100644 --- a/pkg/compose/stop.go +++ b/pkg/compose/stop.go @@ -44,7 +44,7 @@ func (s *composeService) stop(ctx context.Context, project *types.Project, optio return err } - return InReverseDependencyOrder(ctx, project, func(c context.Context, service types.ServiceConfig) error { - return s.stopContainers(ctx, w, containers.filter(isService(service.Name)), options.Timeout) + return InReverseDependencyOrder(ctx, project, func(c context.Context, service string) error { + return s.stopContainers(ctx, w, containers.filter(isService(service)), options.Timeout) }) }