diff --git a/cmd/compose/compose.go b/cmd/compose/compose.go index 7dd593e7b..f91b0ec82 100644 --- a/cmd/compose/compose.go +++ b/cmd/compose/compose.go @@ -83,6 +83,8 @@ func AdaptCmd(fn CobraCommand) func(cmd *cobra.Command, args []string) error { go func() { <-s cancel() + signal.Stop(s) + close(s) }() } err := fn(ctx, cmd, args) diff --git a/pkg/compose/up.go b/pkg/compose/up.go index 64a0743c8..4b3428376 100644 --- a/pkg/compose/up.go +++ b/pkg/compose/up.go @@ -21,7 +21,6 @@ import ( "fmt" "os" "os/signal" - "sync" "syscall" "github.com/compose-spec/compose-go/types" @@ -55,76 +54,79 @@ func (s *composeService) Up(ctx context.Context, project *types.Project, options return err } + var eg multierror.Group + // if we get a second signal during shutdown, we kill the services // immediately, so the channel needs to have sufficient capacity or // we might miss a signal while setting up the second channel read // (this is also why signal.Notify is used vs signal.NotifyContext) signalChan := make(chan os.Signal, 2) signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM) - signalCancel := sync.OnceFunc(func() { - signal.Stop(signalChan) - close(signalChan) + defer close(signalChan) + var isTerminated bool + + doneCh := make(chan bool) + eg.Go(func() error { + first := true + for { + select { + case <-doneCh: + return nil + case <-signalChan: + if first { + fmt.Fprintln(s.stdinfo(), "Gracefully stopping... (press Ctrl+C again to force)") + eg.Go(func() error { + err := s.Stop(context.Background(), project.Name, api.StopOptions{ + Services: options.Create.Services, + Project: project, + }) + isTerminated = true + close(doneCh) + return err + }) + first = false + } else { + eg.Go(func() error { + return s.Kill(context.Background(), project.Name, api.KillOptions{ + Services: options.Create.Services, + Project: project, + }) + }) + return nil + } + } + } }) - defer signalCancel() printer := newLogPrinter(options.Start.Attach) - stopFunc := func() error { - fmt.Fprintln(s.stdinfo(), "Aborting on container exit...") - ctx := context.Background() - return progress.Run(ctx, func(ctx context.Context) error { - // race two goroutines - one that blocks until another signal is received - // and then does a Kill() and one that immediately starts a friendly Stop() - errCh := make(chan error, 1) - go func() { - if _, ok := <-signalChan; !ok { - // channel closed, so the outer function is done, which - // means the other goroutine (calling Stop()) finished - return - } - errCh <- s.Kill(ctx, project.Name, api.KillOptions{ - Services: options.Create.Services, - Project: project, - }) - }() - - go func() { - errCh <- s.Stop(ctx, project.Name, api.StopOptions{ - Services: options.Create.Services, - Project: project, - }) - }() - return <-errCh - }, s.stdinfo()) - } - - var isTerminated bool - var eg multierror.Group - eg.Go(func() error { - if _, ok := <-signalChan; !ok { - // function finished without receiving a signal - return nil - } - isTerminated = true - printer.Cancel() - fmt.Fprintln(s.stdinfo(), "Gracefully stopping... (press Ctrl+C again to force)") - return stopFunc() - }) var exitCode int eg.Go(func() error { - code, err := printer.Run(options.Start.CascadeStop, options.Start.ExitCodeFrom, stopFunc) + code, err := printer.Run(options.Start.CascadeStop, options.Start.ExitCodeFrom, func() error { + fmt.Fprintln(s.stdinfo(), "Aborting on container exit...") + return progress.Run(ctx, func(ctx context.Context) error { + return s.Stop(ctx, project.Name, api.StopOptions{ + Services: options.Create.Services, + Project: project, + }) + }, s.stdinfo()) + }) exitCode = code return err }) - err = s.start(ctx, project.Name, options.Start, printer.HandleEvent) + // We don't use parent (cancelable) context as we manage sigterm to stop the stack + err = s.start(context.Background(), project.Name, options.Start, printer.HandleEvent) if err != nil && !isTerminated { // Ignore error if the process is terminated return err } - // signal for the goroutines to stop & wait for them to finish any remaining work - signalCancel() printer.Stop() + + if !isTerminated { + // signal for the signal-handler goroutines to stop + close(doneCh) + } err = eg.Wait().ErrorOrNil() if exitCode != 0 { errMsg := ""