diff --git a/local/compose/run.go b/local/compose/run.go index 2c7b1e6d1..804174693 100644 --- a/local/compose/run.go +++ b/local/compose/run.go @@ -19,13 +19,15 @@ package compose import ( "context" "fmt" - "io" "os" "github.com/compose-spec/compose-go/types" "github.com/docker/compose-cli/api/compose" - convert "github.com/docker/compose-cli/local/moby" + "github.com/docker/compose-cli/utils" apitypes "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/filters" + "golang.org/x/sync/errgroup" + moby "github.com/docker/docker/pkg/stringid" ) @@ -65,52 +67,33 @@ func (s *composeService) RunOneOffContainer(ctx context.Context, project *types. return containerID, s.apiClient.ContainerStart(ctx, containerID, apitypes.ContainerStartOptions{}) } - cnx, err := s.apiClient.ContainerAttach(ctx, containerID, apitypes.ContainerAttachOptions{ - Stream: true, - Stdin: true, - Stdout: true, - Stderr: true, - Logs: true, + containers, err := s.apiClient.ContainerList(ctx, apitypes.ContainerListOptions{ + Filters: filters.NewArgs( + projectFilter(project.Name), + ), + All: true, }) if err != nil { - return containerID, err + return "", err } - defer cnx.Close() - - stdout := convert.ContainerStdout{HijackedResponse: cnx} - stdin := convert.ContainerStdin{HijackedResponse: cnx} - - readChannel := make(chan error, 10) - writeChannel := make(chan error, 10) - - go func() { - _, err := io.Copy(os.Stdout, cnx.Reader) - readChannel <- err - }() - - go func() { - _, err := io.Copy(stdin, os.Stdin) - writeChannel <- err - }() - - go func() { - <-ctx.Done() - stdout.Close() //nolint:errcheck - stdin.Close() //nolint:errcheck - }() - - // start container - err = s.apiClient.ContainerStart(ctx, containerID, apitypes.ContainerStartOptions{}) - if err != nil { - return containerID, err - } - - for { - select { - case err := <-readChannel: - return containerID, err - case err := <-writeChannel: - return containerID, err + var oneoffContainer apitypes.Container + for _, container := range containers { + if utils.StringContains(container.Names, "/"+containerID) { + oneoffContainer = container } } + eg := errgroup.Group{} + eg.Go(func() error { + return s.attachContainerStreams(ctx, oneoffContainer, true, os.Stdin, os.Stdout) + }) + if err != nil { + return "", err + } + + err = s.apiClient.ContainerStart(ctx, containerID, apitypes.ContainerStartOptions{}) + if err != nil { + return "", err + } + err = eg.Wait() + return containerID, err }