diff --git a/pkg/compose/plugins.go b/pkg/compose/plugins.go index a905546fe..315f464cd 100644 --- a/pkg/compose/plugins.go +++ b/pkg/compose/plugins.go @@ -61,15 +61,33 @@ func (s *composeService) runPlugin(ctx context.Context, project *types.Project, cmd := s.setupPluginCommand(ctx, project, provider, plugin.Path, command) - eg := errgroup.Group{} - stdout, err := cmd.StdoutPipe() + variables, err := s.executePlugin(ctx, cmd, command, service) if err != nil { return err } + for name, s := range project.Services { + if _, ok := s.DependsOn[service.Name]; ok { + prefix := strings.ToUpper(service.Name) + "_" + for key, val := range variables { + s.Environment[prefix+key] = &val + } + project.Services[name] = s + } + } + return nil +} + +func (s *composeService) executePlugin(ctx context.Context, cmd *exec.Cmd, command string, service types.ServiceConfig) (types.Mapping, error) { + eg := errgroup.Group{} + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + err = cmd.Start() if err != nil { - return err + return nil, err } eg.Go(cmd.Wait) @@ -79,7 +97,17 @@ func (s *composeService) runPlugin(ctx context.Context, project *types.Project, variables := types.Mapping{} pw := progress.ContextWriter(ctx) - pw.Event(progress.CreatingEvent(service.Name)) + var action string + switch command { + case "up": + pw.Event(progress.CreatingEvent(service.Name)) + action = "create" + case "down": + pw.Event(progress.RemovingEvent(service.Name)) + action = "remove" + default: + return nil, fmt.Errorf("unsupported plugin command: %s", command) + } for { var msg JsonMessage err = decoder.Decode(&msg) @@ -87,42 +115,37 @@ func (s *composeService) runPlugin(ctx context.Context, project *types.Project, break } if err != nil { - return err + return nil, err } switch msg.Type { case ErrorType: pw.Event(progress.ErrorMessageEvent(service.Name, "error")) - return errors.New(msg.Message) + return nil, errors.New(msg.Message) case InfoType: pw.Event(progress.ErrorMessageEvent(service.Name, msg.Message)) case SetEnvType: key, val, found := strings.Cut(msg.Message, "=") if !found { - return fmt.Errorf("invalid response from plugin: %s", msg.Message) + return nil, fmt.Errorf("invalid response from plugin: %s", msg.Message) } variables[key] = val default: - return fmt.Errorf("invalid response from plugin: %s", msg.Type) + return nil, fmt.Errorf("invalid response from plugin: %s", msg.Type) } } err = eg.Wait() if err != nil { pw.Event(progress.ErrorMessageEvent(service.Name, err.Error())) - return fmt.Errorf("failed to create external service: %s", err.Error()) + return nil, fmt.Errorf("failed to %s external service: %s", action, err.Error()) } - pw.Event(progress.CreatedEvent(service.Name)) - - prefix := strings.ToUpper(service.Name) + "_" - for name, s := range project.Services { - if _, ok := s.DependsOn[service.Name]; ok { - for key, val := range variables { - s.Environment[prefix+key] = &val - } - project.Services[name] = s - } + switch command { + case "up": + pw.Event(progress.CreatedEvent(service.Name)) + case "down": + pw.Event(progress.RemovedEvent(service.Name)) } - return nil + return variables, nil } func (s *composeService) getPluginBinaryPath(providerType string) (*manager.Plugin, error) {