Merge pull request #1841 from ndeloof/convergence

This commit is contained in:
Nicolas De loof 2021-06-28 08:47:17 +02:00 committed by GitHub
commit d20c3b0e22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 254 additions and 367 deletions

View File

@ -59,7 +59,7 @@ func TestIPC(t *testing.T) {
t.Run("down", func(t *testing.T) { t.Run("down", func(t *testing.T) {
_ = c.RunDockerCmd("compose", "--project-name", projectName, "down") _ = c.RunDockerCmd("compose", "--project-name", projectName, "down")
}) })
t.Run("stop ipc mode container", func(t *testing.T) { t.Run("remove ipc mode container", func(t *testing.T) {
_ = c.RunDockerCmd("stop", "ipc_mode_container") _ = c.RunDockerCmd("rm", "-f", "ipc_mode_container")
}) })
} }

View File

@ -86,7 +86,7 @@ func TestNetworkAliassesAndLinks(t *testing.T) {
}) })
t.Run("curl links", func(t *testing.T) { t.Run("curl links", func(t *testing.T) {
res := c.RunDockerCmd("compose", "-f", "./fixtures/network-alias/compose.yaml", "--project-name", projectName, "exec", "-T", "container1", "curl", "container") res := c.RunDockerCmd("compose", "-f", "./fixtures/network-alias/compose.yaml", "--project-name", projectName, "exec", "-T", "container1", "curl", "http://container/")
assert.Assert(t, strings.Contains(res.Stdout(), "Welcome to nginx!"), res.Stdout()) assert.Assert(t, strings.Contains(res.Stdout(), "Welcome to nginx!"), res.Stdout())
}) })

View File

@ -21,6 +21,7 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/compose-spec/compose-go/types" "github.com/compose-spec/compose-go/types"
@ -46,76 +47,147 @@ const (
"Remove the custom name to scale the service.\n" "Remove the custom name to scale the service.\n"
) )
func (s *composeService) ensureScale(ctx context.Context, project *types.Project, service types.ServiceConfig, timeout *time.Duration) (*errgroup.Group, []moby.Container, error) { // convergence manages service's container lifecycle.
cState, err := GetContextContainerState(ctx) // Based on initially observed state, it reconciles the existing container with desired state, which might include
if err != nil { // re-creating container, adding or removing replicas, or starting stopped containers.
return nil, nil, err // Cross services dependencies are managed by creating services in expected order and updating `service:xx` reference
} // when a service has converged, so dependent ones can be managed with resolved containers references.
observedState := cState.GetContainers() type convergence struct {
actual := observedState.filter(isService(service.Name)).filter(isNotOneOff) service *composeService
scale, err := getScale(service) observedState map[string]Containers
if err != nil { }
return nil, nil, err
}
eg, _ := errgroup.WithContext(ctx)
if len(actual) < scale {
next, err := nextContainerNumber(actual)
if err != nil {
return nil, actual, err
}
missing := scale - len(actual)
for i := 0; i < missing; i++ {
number := next + i
name := getContainerName(project.Name, service, number)
eg.Go(func() error {
return s.createContainer(ctx, project, service, name, number, false, true)
})
}
}
if len(actual) > scale { func newConvergence(services []string, state Containers, s *composeService) *convergence {
for i := scale; i < len(actual); i++ { observedState := map[string]Containers{}
container := actual[i] for _, s := range services {
observedState[s] = Containers{}
}
for _, c := range state.filter(isNotOneOff) {
service := c.Labels[api.ServiceLabel]
observedState[service] = append(observedState[service], c)
}
return &convergence{
service: s,
observedState: observedState,
}
}
func (c *convergence) apply(ctx context.Context, project *types.Project, options api.CreateOptions) error {
return InDependencyOrder(ctx, project, func(ctx context.Context, name string) error {
service, err := project.GetService(name)
if err != nil {
return err
}
strategy := options.RecreateDependencies
if utils.StringContains(options.Services, name) {
strategy = options.Recreate
}
err = c.ensureService(ctx, project, service, strategy, options.Inherit, options.Timeout)
if err != nil {
return err
}
c.updateProject(project, name)
return nil
})
}
var mu sync.Mutex
// updateProject updates project after service converged, so dependent services relying on `service:xx` can refer to actual containers.
func (c *convergence) updateProject(project *types.Project, service string) {
containers := c.observedState[service]
container := containers[0]
// operation is protected by a Mutex so that we can safely update project.Services while running concurrent convergence on services
mu.Lock()
defer mu.Unlock()
for i, s := range project.Services {
if d := getDependentServiceFromMode(s.NetworkMode); d == service {
s.NetworkMode = types.NetworkModeContainerPrefix + container.ID
}
if d := getDependentServiceFromMode(s.Ipc); d == service {
s.Ipc = types.NetworkModeContainerPrefix + container.ID
}
if d := getDependentServiceFromMode(s.Pid); d == service {
s.Pid = types.NetworkModeContainerPrefix + container.ID
}
var links []string
for _, serviceLink := range s.Links {
parts := strings.Split(serviceLink, ":")
serviceName := serviceLink
serviceAlias := ""
if len(parts) == 2 {
serviceName = parts[0]
serviceAlias = parts[1]
}
if serviceName != service {
links = append(links, serviceLink)
continue
}
for _, container := range containers {
name := getCanonicalContainerName(container)
if serviceAlias != "" {
links = append(links,
fmt.Sprintf("%s:%s", name, serviceAlias))
}
links = append(links,
fmt.Sprintf("%s:%s", name, name),
fmt.Sprintf("%s:%s", name, getContainerNameWithoutProject(container)))
}
s.Links = links
}
project.Services[i] = s
}
}
func (c *convergence) ensureService(ctx context.Context, project *types.Project, service types.ServiceConfig, recreate string, inherit bool, timeout *time.Duration) error {
expected, err := getScale(service)
if err != nil {
return err
}
containers := c.observedState[service.Name]
actual := len(containers)
updated := make(Containers, expected)
eg, _ := errgroup.WithContext(ctx)
for i, container := range containers {
if i > expected {
// Scale Down
eg.Go(func() error { eg.Go(func() error {
err := s.apiClient.ContainerStop(ctx, container.ID, timeout) err := c.service.apiClient.ContainerStop(ctx, container.ID, timeout)
if err != nil { if err != nil {
return err return err
} }
return s.apiClient.ContainerRemove(ctx, container.ID, moby.ContainerRemoveOptions{}) return c.service.apiClient.ContainerRemove(ctx, container.ID, moby.ContainerRemoveOptions{})
})
}
actual = actual[:scale]
}
return eg, actual, nil
}
func (s *composeService) ensureService(ctx context.Context, project *types.Project, service types.ServiceConfig, recreate string, inherit bool, timeout *time.Duration) error {
eg, actual, err := s.ensureScale(ctx, project, service, timeout)
if err != nil {
return err
}
if recreate == api.RecreateNever {
return nil
}
expected, err := ServiceHash(service)
if err != nil {
return err
}
for _, container := range actual {
container := container
name := getContainerProgressName(container)
diverged := container.Labels[api.ConfigHashLabel] != expected
if diverged || recreate == api.RecreateForce || service.Extensions[extLifecycle] == forceRecreate {
eg.Go(func() error {
return s.recreateContainer(ctx, project, service, container, inherit, timeout)
}) })
continue continue
} }
if recreate == api.RecreateNever {
continue
}
// Re-create diverged containers
configHash, err := ServiceHash(service)
if err != nil {
return err
}
name := getContainerProgressName(container)
diverged := container.Labels[api.ConfigHashLabel] != configHash
if diverged || recreate == api.RecreateForce || service.Extensions[extLifecycle] == forceRecreate {
i := i
eg.Go(func() error {
recreated, err := c.service.recreateContainer(ctx, project, service, container, inherit, timeout)
updated[i] = recreated
return err
})
continue
}
// Enforce non-diverged containers are running
w := progress.ContextWriter(ctx) w := progress.ContextWriter(ctx)
switch container.State { switch container.State {
case ContainerRunning: case ContainerRunning:
@ -126,11 +198,31 @@ func (s *composeService) ensureService(ctx context.Context, project *types.Proje
w.Event(progress.CreatedEvent(name)) w.Event(progress.CreatedEvent(name))
default: default:
eg.Go(func() error { eg.Go(func() error {
return s.startContainer(ctx, container) return c.service.startContainer(ctx, container)
}) })
} }
updated[i] = container
} }
return eg.Wait()
next, err := nextContainerNumber(containers)
if err != nil {
return err
}
for i := 0; i < expected-actual; i++ {
// Scale UP
number := next + i
name := getContainerName(project.Name, service, number)
eg.Go(func() error {
container, err := c.service.createContainer(ctx, project, service, name, number, false, true)
updated[actual+i-1] = container
return err
})
continue
}
err = eg.Wait()
c.observedState[service.Name] = updated
return err
} }
func getContainerName(projectName string, service types.ServiceConfig, number int) string { func getContainerName(projectName string, service types.ServiceConfig, number int) string {
@ -220,51 +312,54 @@ func getScale(config types.ServiceConfig) (int, error) {
return scale, err return scale, err
} }
func (s *composeService) createContainer(ctx context.Context, project *types.Project, service types.ServiceConfig, name string, number int, autoRemove bool, useNetworkAliases bool) error { func (s *composeService) createContainer(ctx context.Context, project *types.Project, service types.ServiceConfig,
name string, number int, autoRemove bool, useNetworkAliases bool) (container moby.Container, err error) {
w := progress.ContextWriter(ctx) w := progress.ContextWriter(ctx)
eventName := "Container " + name eventName := "Container " + name
w.Event(progress.CreatingEvent(eventName)) w.Event(progress.CreatingEvent(eventName))
err := s.createMobyContainer(ctx, project, service, name, number, nil, autoRemove, useNetworkAliases) container, err = s.createMobyContainer(ctx, project, service, name, number, nil, autoRemove, useNetworkAliases)
if err != nil { if err != nil {
return err return
} }
w.Event(progress.CreatedEvent(eventName)) w.Event(progress.CreatedEvent(eventName))
return nil return
} }
func (s *composeService) recreateContainer(ctx context.Context, project *types.Project, service types.ServiceConfig, container moby.Container, inherit bool, timeout *time.Duration) error { func (s *composeService) recreateContainer(ctx context.Context, project *types.Project, service types.ServiceConfig,
replaced moby.Container, inherit bool, timeout *time.Duration) (moby.Container, error) {
var created moby.Container
w := progress.ContextWriter(ctx) w := progress.ContextWriter(ctx)
w.Event(progress.NewEvent(getContainerProgressName(container), progress.Working, "Recreate")) w.Event(progress.NewEvent(getContainerProgressName(replaced), progress.Working, "Recreate"))
err := s.apiClient.ContainerStop(ctx, container.ID, timeout) err := s.apiClient.ContainerStop(ctx, replaced.ID, timeout)
if err != nil { if err != nil {
return err return created, err
} }
name := getCanonicalContainerName(container) name := getCanonicalContainerName(replaced)
tmpName := fmt.Sprintf("%s_%s", container.ID[:12], name) tmpName := fmt.Sprintf("%s_%s", replaced.ID[:12], name)
err = s.apiClient.ContainerRename(ctx, container.ID, tmpName) err = s.apiClient.ContainerRename(ctx, replaced.ID, tmpName)
if err != nil { if err != nil {
return err return created, err
} }
number, err := strconv.Atoi(container.Labels[api.ContainerNumberLabel]) number, err := strconv.Atoi(replaced.Labels[api.ContainerNumberLabel])
if err != nil { if err != nil {
return err return created, err
} }
var inherited *moby.Container var inherited *moby.Container
if inherit { if inherit {
inherited = &container inherited = &replaced
} }
err = s.createMobyContainer(ctx, project, service, name, number, inherited, false, true) created, err = s.createMobyContainer(ctx, project, service, name, number, inherited, false, true)
if err != nil { if err != nil {
return err return created, err
} }
err = s.apiClient.ContainerRemove(ctx, container.ID, moby.ContainerRemoveOptions{}) err = s.apiClient.ContainerRemove(ctx, replaced.ID, moby.ContainerRemoveOptions{})
if err != nil { if err != nil {
return err return created, err
} }
w.Event(progress.NewEvent(getContainerProgressName(container), progress.Done, "Recreated")) w.Event(progress.NewEvent(getContainerProgressName(replaced), progress.Done, "Recreated"))
setDependentLifecycle(project, service.Name, forceRecreate) setDependentLifecycle(project, service.Name, forceRecreate)
return nil return created, err
} }
// setDependentLifecycle define the Lifecycle strategy for all services to depend on specified service // setDependentLifecycle define the Lifecycle strategy for all services to depend on specified service
@ -291,35 +386,31 @@ func (s *composeService) startContainer(ctx context.Context, container moby.Cont
return nil return nil
} }
func (s *composeService) createMobyContainer(ctx context.Context, project *types.Project, service types.ServiceConfig, name string, number int, func (s *composeService) createMobyContainer(ctx context.Context, project *types.Project, service types.ServiceConfig,
inherit *moby.Container, name string, number int, inherit *moby.Container, autoRemove bool, useNetworkAliases bool) (moby.Container, error) {
autoRemove bool, var created moby.Container
useNetworkAliases bool) error {
cState, err := GetContextContainerState(ctx)
if err != nil {
return err
}
containerConfig, hostConfig, networkingConfig, err := s.getCreateOptions(ctx, project, service, number, inherit, autoRemove) containerConfig, hostConfig, networkingConfig, err := s.getCreateOptions(ctx, project, service, number, inherit, autoRemove)
if err != nil { if err != nil {
return err return created, err
} }
var plat *specs.Platform var plat *specs.Platform
if service.Platform != "" { if service.Platform != "" {
p, err := platforms.Parse(service.Platform) var p specs.Platform
p, err = platforms.Parse(service.Platform)
if err != nil { if err != nil {
return err return created, err
} }
plat = &p plat = &p
} }
created, err := s.apiClient.ContainerCreate(ctx, containerConfig, hostConfig, networkingConfig, plat, name) response, err := s.apiClient.ContainerCreate(ctx, containerConfig, hostConfig, networkingConfig, plat, name)
if err != nil { if err != nil {
return err return created, err
} }
inspectedContainer, err := s.apiClient.ContainerInspect(ctx, created.ID) inspectedContainer, err := s.apiClient.ContainerInspect(ctx, response.ID)
if err != nil { if err != nil {
return err return created, err
} }
createdContainer := moby.Container{ created = moby.Container{
ID: inspectedContainer.ID, ID: inspectedContainer.ID,
Labels: inspectedContainer.Config.Labels, Labels: inspectedContainer.Config.Labels,
Names: []string{inspectedContainer.Name}, Names: []string{inspectedContainer.Name},
@ -327,11 +418,7 @@ func (s *composeService) createMobyContainer(ctx context.Context, project *types
Networks: inspectedContainer.NetworkSettings.Networks, Networks: inspectedContainer.NetworkSettings.Networks,
}, },
} }
cState.Add(createdContainer) links := append(service.Links, service.ExternalLinks...)
links, err := s.getLinks(ctx, service)
if err != nil {
return err
}
for _, netName := range service.NetworksByPriority() { for _, netName := range service.NetworksByPriority() {
netwrk := project.Networks[netName] netwrk := project.Networks[netName]
cfg := service.Networks[netName] cfg := service.Networks[netName]
@ -342,21 +429,21 @@ func (s *composeService) createMobyContainer(ctx context.Context, project *types
aliases = append(aliases, cfg.Aliases...) aliases = append(aliases, cfg.Aliases...)
} }
} }
if val, ok := createdContainer.NetworkSettings.Networks[netwrk.Name]; ok { if val, ok := created.NetworkSettings.Networks[netwrk.Name]; ok {
if shortIDAliasExists(createdContainer.ID, val.Aliases...) { if shortIDAliasExists(created.ID, val.Aliases...) {
continue continue
} }
err := s.apiClient.NetworkDisconnect(ctx, netwrk.Name, createdContainer.ID, false) err = s.apiClient.NetworkDisconnect(ctx, netwrk.Name, created.ID, false)
if err != nil { if err != nil {
return err return created, err
} }
} }
err = s.connectContainerToNetwork(ctx, created.ID, netwrk.Name, cfg, links, aliases...) err = s.connectContainerToNetwork(ctx, created.ID, netwrk.Name, cfg, links, aliases...)
if err != nil { if err != nil {
return err return created, err
} }
} }
return nil return created, err
} }
func shortIDAliasExists(containerID string, aliases ...string) bool { func shortIDAliasExists(containerID string, aliases ...string) bool {
@ -395,37 +482,6 @@ func (s *composeService) connectContainerToNetwork(ctx context.Context, id strin
return nil return nil
} }
func (s *composeService) getLinks(ctx context.Context, service types.ServiceConfig) ([]string, error) {
cState, err := GetContextContainerState(ctx)
if err != nil {
return nil, err
}
links := []string{}
for _, serviceLink := range service.Links {
s := strings.Split(serviceLink, ":")
serviceName := serviceLink
serviceAlias := ""
if len(s) == 2 {
serviceName = s[0]
serviceAlias = s[1]
}
containers := cState.GetContainers()
depServiceContainers := containers.filter(isService(serviceName))
for _, container := range depServiceContainers {
name := getCanonicalContainerName(container)
if serviceAlias != "" {
links = append(links,
fmt.Sprintf("%s:%s", name, serviceAlias))
}
links = append(links,
fmt.Sprintf("%s:%s", name, name),
fmt.Sprintf("%s:%s", name, getContainerNameWithoutProject(container)))
}
}
links = append(links, service.ExternalLinks...)
return links, nil
}
func (s *composeService) isServiceHealthy(ctx context.Context, project *types.Project, service string) (bool, error) { func (s *composeService) isServiceHealthy(ctx context.Context, project *types.Project, service string) (bool, error) {
containers, err := s.getContainers(ctx, project.Name, oneOffExclude, false, service) containers, err := s.getContainers(ctx, project.Name, oneOffExclude, false, service)
if err != nil { if err != nil {
@ -503,26 +559,3 @@ func (s *composeService) startService(ctx context.Context, project *types.Projec
} }
return eg.Wait() return eg.Wait()
} }
func (s *composeService) restartService(ctx context.Context, serviceName string, timeout *time.Duration) error {
containerState, err := GetContextContainerState(ctx)
if err != nil {
return err
}
containers := containerState.GetContainers().filter(isService(serviceName))
w := progress.ContextWriter(ctx)
eg, ctx := errgroup.WithContext(ctx)
for _, c := range containers {
container := c
eg.Go(func() error {
eventName := getContainerProgressName(container)
w.Event(progress.RestartingEvent(eventName))
err := s.apiClient.ContainerRestart(ctx, container.ID, timeout)
if err == nil {
w.Event(progress.StartedEvent(eventName))
}
return err
})
}
return eg.Wait()
}

View File

@ -59,8 +59,6 @@ func (s *composeService) create(ctx context.Context, project *types.Project, opt
if err != nil { if err != nil {
return err return err
} }
containerState := NewContainersState(observedState)
ctx = context.WithValue(ctx, ContainersKey{}, containerState)
err = s.ensureImagesExists(ctx, project, observedState, options.QuietPull) err = s.ensureImagesExists(ctx, project, observedState, options.QuietPull)
if err != nil { if err != nil {
@ -105,12 +103,7 @@ func (s *composeService) create(ctx context.Context, project *types.Project, opt
prepareServicesDependsOn(project) prepareServicesDependsOn(project)
return InDependencyOrder(ctx, project, func(c context.Context, service types.ServiceConfig) error { return newConvergence(options.Services, observedState, s).apply(ctx, project, options)
if utils.StringContains(options.Services, service.Name) {
return s.ensureService(c, project, service, options.Recreate, options.Inherit, options.Timeout)
}
return s.ensureService(c, project, service, options.RecreateDependencies, options.Inherit, options.Timeout)
})
} }
func prepareVolumes(p *types.Project) error { func prepareVolumes(p *types.Project) error {
@ -275,12 +268,8 @@ func (s *composeService) getCreateOptions(ctx context.Context, p *types.Project,
resources := getDeployResources(service) resources := getDeployResources(service)
networkMode, err := getMode(ctx, service.Name, service.NetworkMode) if service.NetworkMode == "" {
if err != nil { service.NetworkMode = getDefaultNetworkMode(p, service)
return nil, nil, nil, err
}
if networkMode == "" {
networkMode = getDefaultNetworkMode(p, service)
} }
var networkConfig *network.NetworkingConfig var networkConfig *network.NetworkingConfig
@ -314,11 +303,6 @@ func (s *composeService) getCreateOptions(ctx context.Context, p *types.Project,
break //nolint:staticcheck break //nolint:staticcheck
} }
ipcmode, err := getMode(ctx, service.Name, service.Ipc)
if err != nil {
return nil, nil, nil, err
}
tmpfs := map[string]string{} tmpfs := map[string]string{}
for _, t := range service.Tmpfs { for _, t := range service.Tmpfs {
if arr := strings.SplitN(t, ":", 2); len(arr) > 1 { if arr := strings.SplitN(t, ":", 2); len(arr) > 1 {
@ -342,9 +326,9 @@ func (s *composeService) getCreateOptions(ctx context.Context, p *types.Project,
Mounts: mounts, Mounts: mounts,
CapAdd: strslice.StrSlice(service.CapAdd), CapAdd: strslice.StrSlice(service.CapAdd),
CapDrop: strslice.StrSlice(service.CapDrop), CapDrop: strslice.StrSlice(service.CapDrop),
NetworkMode: container.NetworkMode(networkMode), NetworkMode: container.NetworkMode(service.NetworkMode),
Init: service.Init, Init: service.Init,
IpcMode: container.IpcMode(ipcmode), IpcMode: container.IpcMode(service.Ipc),
ReadonlyRootfs: service.ReadOnly, ReadonlyRootfs: service.ReadOnly,
RestartPolicy: getRestartPolicy(service), RestartPolicy: getRestartPolicy(service),
ShmSize: int64(service.ShmSize), ShmSize: int64(service.ShmSize),
@ -913,24 +897,6 @@ func getAliases(s types.ServiceConfig, c *types.ServiceNetworkConfig) []string {
return aliases return aliases
} }
func getMode(ctx context.Context, serviceName string, mode string) (string, error) {
cState, err := GetContextContainerState(ctx)
if err != nil {
return "", nil
}
observedState := cState.GetContainers()
depService := getDependentServiceFromMode(mode)
if depService != "" {
depServiceContainers := observedState.filter(isService(depService))
if len(depServiceContainers) > 0 {
return types.NetworkModeContainerPrefix + depServiceContainers[0].ID, nil
}
return "", fmt.Errorf(`no containers started for %q in service %q -> %v`,
mode, serviceName, observedState)
}
return mode, nil
}
func getNetworksForService(s types.ServiceConfig) map[string]*types.ServiceNetworkConfig { func getNetworksForService(s types.ServiceConfig) map[string]*types.ServiceNetworkConfig {
if len(s.Networks) > 0 { if len(s.Networks) > 0 {
return s.Networks return s.Networks

View File

@ -63,16 +63,16 @@ var (
) )
// InDependencyOrder applies the function to the services of the project taking in account the dependency order // InDependencyOrder applies the function to the services of the project taking in account the dependency order
func InDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, types.ServiceConfig) error) error { func InDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error) error {
return visit(ctx, project, upDirectionTraversalConfig, fn, ServiceStopped) return visit(ctx, project, upDirectionTraversalConfig, fn, ServiceStopped)
} }
// InReverseDependencyOrder applies the function to the services of the project in reverse order of dependencies // InReverseDependencyOrder applies the function to the services of the project in reverse order of dependencies
func InReverseDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, types.ServiceConfig) error) error { func InReverseDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error) error {
return visit(ctx, project, downDirectionTraversalConfig, fn, ServiceStarted) return visit(ctx, project, downDirectionTraversalConfig, fn, ServiceStarted)
} }
func visit(ctx context.Context, project *types.Project, traversalConfig graphTraversalConfig, fn func(context.Context, types.ServiceConfig) error, initialStatus ServiceStatus) error { func visit(ctx context.Context, project *types.Project, traversalConfig graphTraversalConfig, fn func(context.Context, string) error, initialStatus ServiceStatus) error {
g := NewGraph(project.Services, initialStatus) g := NewGraph(project.Services, initialStatus)
if b, err := g.HasCycles(); b { if b, err := g.HasCycles(); b {
return err return err
@ -89,12 +89,12 @@ func visit(ctx context.Context, project *types.Project, traversalConfig graphTra
} }
// Note: this could be `graph.walk` or whatever // Note: this could be `graph.walk` or whatever
func run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex, traversalConfig graphTraversalConfig, fn func(context.Context, types.ServiceConfig) error) error { func run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex, traversalConfig graphTraversalConfig, fn func(context.Context, string) error) error {
for _, node := range nodes { for _, node := range nodes {
n := node n := node
// Don't start this service yet if all of its children have // Don't start this service yet if all of its children have
// not been started yet. // not been started yet.
if len(traversalConfig.filterAdjacentByStatusFn(graph, n.Service.Name, traversalConfig.adjacentServiceStatusToSkip)) != 0 { if len(traversalConfig.filterAdjacentByStatusFn(graph, n.Service, traversalConfig.adjacentServiceStatusToSkip)) != 0 {
continue continue
} }
@ -104,7 +104,7 @@ func run(ctx context.Context, graph *Graph, eg *errgroup.Group, nodes []*Vertex,
return err return err
} }
graph.UpdateStatus(n.Service.Name, traversalConfig.targetServiceStatus) graph.UpdateStatus(n.Service, traversalConfig.targetServiceStatus)
return run(ctx, graph, eg, traversalConfig.adjacentNodesFn(n), traversalConfig, fn) return run(ctx, graph, eg, traversalConfig.adjacentNodesFn(n), traversalConfig, fn)
}) })
@ -122,7 +122,7 @@ type Graph struct {
// Vertex represents a service in the dependencies structure // Vertex represents a service in the dependencies structure
type Vertex struct { type Vertex struct {
Key string Key string
Service types.ServiceConfig Service string
Status ServiceStatus Status ServiceStatus
Children map[string]*Vertex Children map[string]*Vertex
Parents map[string]*Vertex Parents map[string]*Vertex
@ -162,7 +162,7 @@ func NewGraph(services types.Services, initialStatus ServiceStatus) *Graph {
} }
for _, s := range services { for _, s := range services {
graph.AddVertex(s.Name, s, initialStatus) graph.AddVertex(s.Name, s.Name, initialStatus)
} }
for _, s := range services { for _, s := range services {
@ -175,7 +175,7 @@ func NewGraph(services types.Services, initialStatus ServiceStatus) *Graph {
} }
// NewVertex is the constructor function for the Vertex // NewVertex is the constructor function for the Vertex
func NewVertex(key string, service types.ServiceConfig, initialStatus ServiceStatus) *Vertex { func NewVertex(key string, service string, initialStatus ServiceStatus) *Vertex {
return &Vertex{ return &Vertex{
Key: key, Key: key,
Service: service, Service: service,
@ -186,7 +186,7 @@ func NewVertex(key string, service types.ServiceConfig, initialStatus ServiceSta
} }
// AddVertex adds a vertex to the Graph // AddVertex adds a vertex to the Graph
func (g *Graph) AddVertex(key string, service types.ServiceConfig, initialStatus ServiceStatus) { func (g *Graph) AddVertex(key string, service string, initialStatus ServiceStatus) {
g.lock.Lock() g.lock.Lock()
defer g.lock.Unlock() defer g.lock.Unlock()

View File

@ -47,8 +47,8 @@ var project = types.Project{
func TestInDependencyUpCommandOrder(t *testing.T) { func TestInDependencyUpCommandOrder(t *testing.T) {
order := make(chan string) order := make(chan string)
//nolint:errcheck, unparam //nolint:errcheck, unparam
go InDependencyOrder(context.TODO(), &project, func(ctx context.Context, config types.ServiceConfig) error { go InDependencyOrder(context.TODO(), &project, func(ctx context.Context, config string) error {
order <- config.Name order <- config
return nil return nil
}) })
assert.Equal(t, <-order, "test3") assert.Equal(t, <-order, "test3")
@ -59,8 +59,8 @@ func TestInDependencyUpCommandOrder(t *testing.T) {
func TestInDependencyReverseDownCommandOrder(t *testing.T) { func TestInDependencyReverseDownCommandOrder(t *testing.T) {
order := make(chan string) order := make(chan string)
//nolint:errcheck, unparam //nolint:errcheck, unparam
go InReverseDependencyOrder(context.TODO(), &project, func(ctx context.Context, config types.ServiceConfig) error { go InReverseDependencyOrder(context.TODO(), &project, func(ctx context.Context, config string) error {
order <- config.Name order <- config
return nil return nil
}) })
assert.Equal(t, <-order, "test1") assert.Equal(t, <-order, "test1")

View File

@ -50,7 +50,6 @@ func (s *composeService) down(ctx context.Context, projectName string, options a
if err != nil { if err != nil {
return err return err
} }
ctx = context.WithValue(ctx, ContainersKey{}, NewContainersState(containers))
if options.Project == nil { if options.Project == nil {
project, err := s.projectFromContainerLabels(containers, projectName) project, err := s.projectFromContainerLabels(containers, projectName)
@ -64,8 +63,8 @@ func (s *composeService) down(ctx context.Context, projectName string, options a
resourceToRemove = true resourceToRemove = true
} }
err = InReverseDependencyOrder(ctx, options.Project, func(c context.Context, service types.ServiceConfig) error { err = InReverseDependencyOrder(ctx, options.Project, func(c context.Context, service string) error {
serviceContainers := containers.filter(isService(service.Name)) serviceContainers := containers.filter(isService(service))
err := s.removeContainers(ctx, w, serviceContainers, options.Timeout, options.Volumes) err := s.removeContainers(ctx, w, serviceContainers, options.Timeout, options.Volumes)
return err return err
}) })
@ -236,11 +235,6 @@ func (s *composeService) removeContainers(ctx context.Context, w progress.Writer
w.Event(progress.ErrorMessageEvent(eventName, "Error while Removing")) w.Event(progress.ErrorMessageEvent(eventName, "Error while Removing"))
return err return err
} }
contextContainerState, err := GetContextContainerState(ctx)
if err != nil {
return err
}
contextContainerState.Remove(toDelete.ID)
w.Event(progress.RemovedEvent(eventName)) w.Event(progress.RemovedEvent(eventName))
return nil return nil
}) })

View File

@ -31,10 +31,6 @@ func serviceFilter(serviceName string) filters.KeyValuePair {
return filters.Arg("label", fmt.Sprintf("%s=%s", api.ServiceLabel, serviceName)) return filters.Arg("label", fmt.Sprintf("%s=%s", api.ServiceLabel, serviceName))
} }
func slugFilter(slug string) filters.KeyValuePair {
return filters.Arg("label", fmt.Sprintf("%s=%s", api.SlugLabel, slug))
}
func oneOffFilter(b bool) filters.KeyValuePair { func oneOffFilter(b bool) filters.KeyValuePair {
v := "False" v := "False"
if b { if b {

View File

@ -21,6 +21,7 @@ import (
"github.com/compose-spec/compose-go/types" "github.com/compose-spec/compose-go/types"
"github.com/docker/compose-cli/pkg/api" "github.com/docker/compose-cli/pkg/api"
"golang.org/x/sync/errgroup"
"github.com/docker/compose-cli/pkg/progress" "github.com/docker/compose-cli/pkg/progress"
"github.com/docker/compose-cli/pkg/utils" "github.com/docker/compose-cli/pkg/utils"
@ -33,7 +34,7 @@ func (s *composeService) Restart(ctx context.Context, project *types.Project, op
} }
func (s *composeService) restart(ctx context.Context, project *types.Project, options api.RestartOptions) error { func (s *composeService) restart(ctx context.Context, project *types.Project, options api.RestartOptions) error {
ctx, err := s.getUpdatedContainersStateContext(ctx, project.Name) observedState, err := s.getContainers(ctx, project.Name, oneOffInclude, true)
if err != nil { if err != nil {
return err return err
} }
@ -42,11 +43,25 @@ func (s *composeService) restart(ctx context.Context, project *types.Project, op
options.Services = project.ServiceNames() options.Services = project.ServiceNames()
} }
err = InDependencyOrder(ctx, project, func(c context.Context, service types.ServiceConfig) error { w := progress.ContextWriter(ctx)
if utils.StringContains(options.Services, service.Name) { err = InDependencyOrder(ctx, project, func(c context.Context, service string) error {
return s.restartService(ctx, service.Name, options.Timeout) if !utils.StringContains(options.Services, service) {
return nil
} }
return nil eg, ctx := errgroup.WithContext(ctx)
for _, c := range observedState.filter(isService(service)) {
container := c
eg.Go(func() error {
eventName := getContainerProgressName(container)
w.Event(progress.RestartingEvent(eventName))
err := s.apiClient.ContainerRestart(ctx, container.ID, options.Timeout)
if err == nil {
w.Event(progress.StartedEvent(eventName))
}
return err
})
}
return eg.Wait()
}) })
if err != nil { if err != nil {
return err return err

View File

@ -25,7 +25,6 @@ import (
"github.com/compose-spec/compose-go/types" "github.com/compose-spec/compose-go/types"
moby "github.com/docker/docker/api/types" moby "github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/pkg/stringid" "github.com/docker/docker/pkg/stringid"
) )
@ -34,8 +33,6 @@ func (s *composeService) RunOneOffContainer(ctx context.Context, project *types.
if err != nil { if err != nil {
return 0, err return 0, err
} }
containerState := NewContainersState(observedState)
ctx = context.WithValue(ctx, ContainersKey{}, containerState)
service, err := project.GetService(opts.Service) service, err := project.GetService(opts.Service)
if err != nil { if err != nil {
@ -63,10 +60,11 @@ func (s *composeService) RunOneOffContainer(ctx context.Context, project *types.
if err := s.waitDependencies(ctx, project, service); err != nil { if err := s.waitDependencies(ctx, project, service); err != nil {
return 0, err return 0, err
} }
if err := s.createContainer(ctx, project, service, service.ContainerName, 1, opts.AutoRemove, opts.UseNetworkAliases); err != nil { created, err := s.createContainer(ctx, project, service, service.ContainerName, 1, opts.AutoRemove, opts.UseNetworkAliases)
if err != nil {
return 0, err return 0, err
} }
containerID := service.ContainerName containerID := created.ID
if opts.Detach { if opts.Detach {
err := s.apiClient.ContainerStart(ctx, containerID, moby.ContainerStartOptions{}) err := s.apiClient.ContainerStart(ctx, containerID, moby.ContainerStartOptions{})
@ -77,21 +75,13 @@ func (s *composeService) RunOneOffContainer(ctx context.Context, project *types.
return 0, nil return 0, nil
} }
containers, err := s.apiClient.ContainerList(ctx, moby.ContainerListOptions{ restore, err := s.attachContainerStreams(ctx, containerID, service.Tty, opts.Reader, opts.Writer)
Filters: filters.NewArgs(slugFilter(slug)),
All: true,
})
if err != nil {
return 0, err
}
oneoffContainer := containers[0]
restore, err := s.attachContainerStreams(ctx, oneoffContainer.ID, service.Tty, opts.Reader, opts.Writer)
if err != nil { if err != nil {
return 0, err return 0, err
} }
defer restore() defer restore()
statusC, errC := s.apiClient.ContainerWait(context.Background(), oneoffContainer.ID, container.WaitConditionNextExit) statusC, errC := s.apiClient.ContainerWait(context.Background(), containerID, container.WaitConditionNextExit)
err = s.apiClient.ContainerStart(ctx, containerID, moby.ContainerStartOptions{}) err = s.apiClient.ContainerStart(ctx, containerID, moby.ContainerStartOptions{})
if err != nil { if err != nil {

View File

@ -53,7 +53,11 @@ func (s *composeService) start(ctx context.Context, project *types.Project, opti
}) })
} }
err := InDependencyOrder(ctx, project, func(c context.Context, service types.ServiceConfig) error { err := InDependencyOrder(ctx, project, func(c context.Context, name string) error {
service, err := project.GetService(name)
if err != nil {
return err
}
return s.startService(ctx, project, service) return s.startService(ctx, project, service)
}) })
if err != nil { if err != nil {

View File

@ -1,111 +0,0 @@
/*
Copyright 2020 Docker Compose CLI authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package compose
import (
"context"
"github.com/docker/docker/api/types"
"github.com/pkg/errors"
)
// ContainersKey is the context key to access context value os a ContainersStatus
type ContainersKey struct{}
// ContainersState state management interface
type ContainersState interface {
Get(string) *types.Container
GetContainers() Containers
Add(c types.Container)
AddAll(cs Containers)
Remove(string) types.Container
}
// NewContainersState creates a new container state manager
func NewContainersState(cs Containers) ContainersState {
s := containersState{
observedContainers: &cs,
}
return &s
}
// ContainersStatus works as a collection container for the observed containers
type containersState struct {
observedContainers *Containers
}
func (s *containersState) AddAll(cs Containers) {
for _, c := range cs {
lValue := append(*s.observedContainers, c)
s.observedContainers = &lValue
}
}
func (s *containersState) Add(c types.Container) {
if s.Get(c.ID) == nil {
lValue := append(*s.observedContainers, c)
s.observedContainers = &lValue
}
}
func (s *containersState) Remove(id string) types.Container {
var c types.Container
var newObserved Containers
for _, o := range *s.observedContainers {
if o.ID == id {
c = o
continue
}
newObserved = append(newObserved, o)
}
s.observedContainers = &newObserved
return c
}
func (s *containersState) Get(id string) *types.Container {
for _, o := range *s.observedContainers {
if id == o.ID {
return &o
}
}
return nil
}
func (s *containersState) GetContainers() Containers {
if s.observedContainers != nil && *s.observedContainers != nil {
return *s.observedContainers
}
return make(Containers, 0)
}
// GetContextContainerState gets the container state manager
func GetContextContainerState(ctx context.Context) (ContainersState, error) {
cState, ok := ctx.Value(ContainersKey{}).(*containersState)
if !ok {
return nil, errors.New("containers' containersState not available in context")
}
return cState, nil
}
func (s composeService) getUpdatedContainersStateContext(ctx context.Context, projectName string) (context.Context, error) {
observedState, err := s.getContainers(ctx, projectName, oneOffInclude, true)
if err != nil {
return nil, err
}
containerState := NewContainersState(observedState)
return context.WithValue(ctx, ContainersKey{}, containerState), nil
}

View File

@ -44,7 +44,7 @@ func (s *composeService) stop(ctx context.Context, project *types.Project, optio
return err return err
} }
return InReverseDependencyOrder(ctx, project, func(c context.Context, service types.ServiceConfig) error { return InReverseDependencyOrder(ctx, project, func(c context.Context, service string) error {
return s.stopContainers(ctx, w, containers.filter(isService(service.Name)), options.Timeout) return s.stopContainers(ctx, w, containers.filter(isService(service)), options.Timeout)
}) })
} }