diff --git a/pkg/compose/model.go b/pkg/compose/model.go index 9922028e0..68e1d425c 100644 --- a/pkg/compose/model.go +++ b/pkg/compose/model.go @@ -75,7 +75,7 @@ func (s *composeService) ensureModels(ctx context.Context, project *types.Projec eg, gctx := errgroup.WithContext(ctx) eg.Go(func() error { - return s.setModelEndpointVariable(gctx, dockerModel, project) + return s.setModelVariables(gctx, dockerModel, project) }) for name, config := range project.Models { @@ -161,7 +161,7 @@ func (s *composeService) configureModel(ctx context.Context, dockerModel *manage return cmd.Run() } -func (s *composeService) setModelEndpointVariable(ctx context.Context, dockerModel *manager.Plugin, project *types.Project) error { +func (s *composeService) setModelVariables(ctx context.Context, dockerModel *manager.Plugin, project *types.Project) error { cmd := exec.CommandContext(ctx, dockerModel.Path, "status", "--json") s.setupChildProcess(ctx, cmd) statusOut, err := cmd.CombinedOutput() @@ -179,12 +179,21 @@ func (s *composeService) setModelEndpointVariable(ctx context.Context, dockerMod } for _, service := range project.Services { - for model, modelConfig := range service.Models { + for ref, modelConfig := range service.Models { + model := project.Models[ref] + varPrefix := strings.ReplaceAll(strings.ToUpper(ref), "-", "_") var variable string + if modelConfig != nil && modelConfig.ModelVariable != "" { + variable = modelConfig.ModelVariable + } else { + variable = varPrefix + } + service.Environment[variable] = &model.Model + if modelConfig != nil && modelConfig.EndpointVariable != "" { variable = modelConfig.EndpointVariable } else { - variable = strings.ToUpper(model) + "_URL" + variable = varPrefix + "_URL" } service.Environment[variable] = &status.Endpoint }