Merge pull request #764 from docker/test_with_mock

More unit tests with aws-sdk behind an interface + mocks
This commit is contained in:
Nicolas De loof 2020-10-12 21:21:38 +02:00 committed by GitHub
commit 45bb05ee19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 923 additions and 226 deletions

View File

@ -30,7 +30,7 @@ services:
image: hello_world
deploy:
x-aws-autoscaling: 75
`)
`, useDefaultVPC)
target := template.Resources["FooScalableTarget"].(*autoscaling.ScalableTarget)
assert.Check(t, target != nil)
policy := template.Resources["FooScalingPolicy"].(*autoscaling.ScalingPolicy)

View File

@ -16,7 +16,61 @@
package ecs
import (
"context"
"github.com/aws/aws-sdk-go/service/cloudformation"
"github.com/aws/aws-sdk-go/service/ecs"
"github.com/docker/compose-cli/api/compose"
"github.com/docker/compose-cli/api/secrets"
)
const (
awsTypeCapacityProvider = "AWS::ECS::CapacityProvider"
awsTypeAutoscalingGroup = "AWS::AutoScaling::AutoScalingGroup"
)
//go:generate mockgen -destination=./aws_mock.go -self_package "github.com/docker/compose-cli/ecs" -package=ecs . API
// API hides aws-go-sdk into a simpler, focussed API subset
type API interface {
CheckRequirements(ctx context.Context, region string) error
ClusterExists(ctx context.Context, name string) (bool, error)
CreateCluster(ctx context.Context, name string) (string, error)
CheckVPC(ctx context.Context, vpcID string) error
GetDefaultVPC(ctx context.Context) (string, error)
GetSubNets(ctx context.Context, vpcID string) ([]string, error)
GetRoleArn(ctx context.Context, name string) (string, error)
StackExists(ctx context.Context, name string) (bool, error)
CreateStack(ctx context.Context, name string, template []byte) error
CreateChangeSet(ctx context.Context, name string, template []byte) (string, error)
UpdateStack(ctx context.Context, changeset string) error
WaitStackComplete(ctx context.Context, name string, operation int) error
GetStackID(ctx context.Context, name string) (string, error)
ListStacks(ctx context.Context, name string) ([]compose.Stack, error)
GetStackClusterID(ctx context.Context, stack string) (string, error)
GetServiceTaskDefinition(ctx context.Context, cluster string, serviceArns []string) (map[string]string, error)
ListStackServices(ctx context.Context, stack string) ([]string, error)
GetServiceTasks(ctx context.Context, cluster string, service string, stopped bool) ([]*ecs.Task, error)
GetTaskStoppedReason(ctx context.Context, cluster string, taskArn string) (string, error)
DescribeStackEvents(ctx context.Context, stackID string) ([]*cloudformation.StackEvent, error)
ListStackParameters(ctx context.Context, name string) (map[string]string, error)
ListStackResources(ctx context.Context, name string) (stackResources, error)
DeleteStack(ctx context.Context, name string) error
CreateSecret(ctx context.Context, secret secrets.Secret) (string, error)
InspectSecret(ctx context.Context, id string) (secrets.Secret, error)
ListSecrets(ctx context.Context) ([]secrets.Secret, error)
DeleteSecret(ctx context.Context, id string, recover bool) error
GetLogs(ctx context.Context, name string, consumer func(service, container, message string)) error
DescribeService(ctx context.Context, cluster string, arn string) (compose.ServiceStatus, error)
getURLWithPortMapping(ctx context.Context, targetGroupArns []string) ([]compose.PortPublisher, error)
ListTasks(ctx context.Context, cluster string, family string) ([]string, error)
GetPublicIPs(ctx context.Context, interfaces ...string) (map[string]string, error)
LoadBalancerType(ctx context.Context, arn string) (string, error)
GetLoadBalancerURL(ctx context.Context, arn string) (string, error)
WithVolumeSecurityGroups(ctx context.Context, id string, fn func(securityGroups []string) error) error
GetParameter(ctx context.Context, name string) (string, error)
SecurityGroupExists(ctx context.Context, sg string) (bool, error)
DeleteCapacityProvider(ctx context.Context, arn string) error
DeleteAutoscalingGroup(ctx context.Context, arn string) error
}

View File

@ -82,7 +82,7 @@ func (b *ecsAPIService) parse(ctx context.Context, project *types.Project) (awsR
func (b *ecsAPIService) parseClusterExtension(ctx context.Context, project *types.Project) (string, error) {
if x, ok := project.Extensions[extensionCluster]; ok {
cluster := x.(string)
ok, err := b.SDK.ClusterExists(ctx, cluster)
ok, err := b.aws.ClusterExists(ctx, cluster)
if err != nil {
return "", err
}
@ -98,20 +98,20 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr
var vpc string
if x, ok := project.Extensions[extensionVPC]; ok {
vpc = x.(string)
err := b.SDK.CheckVPC(ctx, vpc)
err := b.aws.CheckVPC(ctx, vpc)
if err != nil {
return "", nil, err
}
} else {
defaultVPC, err := b.SDK.GetDefaultVPC(ctx)
defaultVPC, err := b.aws.GetDefaultVPC(ctx)
if err != nil {
return "", nil, err
}
vpc = defaultVPC
}
subNets, err := b.SDK.GetSubNets(ctx, vpc)
subNets, err := b.aws.GetSubNets(ctx, vpc)
if err != nil {
return "", nil, err
}
@ -124,7 +124,7 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr
func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project) (string, string, error) {
if x, ok := project.Extensions[extensionLoadBalancer]; ok {
loadBalancer := x.(string)
loadBalancerType, err := b.SDK.LoadBalancerType(ctx, loadBalancer)
loadBalancerType, err := b.aws.LoadBalancerType(ctx, loadBalancer)
if err != nil {
return "", "", err
}
@ -142,16 +142,16 @@ func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project
func (b *ecsAPIService) parseSecurityGroupExtension(ctx context.Context, project *types.Project) (map[string]string, error) {
securityGroups := make(map[string]string, len(project.Networks))
for name, net := range project.Networks {
var sg string
if net.External.External {
sg = net.Name
if !net.External.External {
continue
}
sg := net.Name
if x, ok := net.Extensions[extensionSecurityGroup]; ok {
logrus.Warn("to use an existing security-group, use `network.external` and `network.name` in your compose file")
logrus.Debugf("Security Group for network %q set by user to %q", net.Name, x)
sg = x.(string)
}
exists, err := b.SDK.SecurityGroupExists(ctx, sg)
exists, err := b.aws.SecurityGroupExists(ctx, sg)
if err != nil {
return nil, err
}
@ -186,6 +186,11 @@ func (b *ecsAPIService) ensureNetworks(r *awsResources, project *types.Project,
r.securityGroups = make(map[string]string, len(project.Networks))
}
for name, net := range project.Networks {
if net.External.External {
r.securityGroups[name] = net.Name
continue
}
securityGroup := networkResourceName(name)
template.Resources[securityGroup] = &ec2.SecurityGroup{
GroupDescription: fmt.Sprintf("%s Security Group for %s network", project.Name, name),

617
ecs/aws_mock.go Normal file
View File

@ -0,0 +1,617 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/docker/compose-cli/ecs (interfaces: API)
// Package ecs is a generated GoMock package.
package ecs
import (
context "context"
cloudformation "github.com/aws/aws-sdk-go/service/cloudformation"
ecs "github.com/aws/aws-sdk-go/service/ecs"
compose "github.com/docker/compose-cli/api/compose"
secrets "github.com/docker/compose-cli/api/secrets"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockAPI is a mock of API interface
type MockAPI struct {
ctrl *gomock.Controller
recorder *MockAPIMockRecorder
}
// MockAPIMockRecorder is the mock recorder for MockAPI
type MockAPIMockRecorder struct {
mock *MockAPI
}
// NewMockAPI creates a new mock instance
func NewMockAPI(ctrl *gomock.Controller) *MockAPI {
mock := &MockAPI{ctrl: ctrl}
mock.recorder = &MockAPIMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockAPI) EXPECT() *MockAPIMockRecorder {
return m.recorder
}
// CheckRequirements mocks base method
func (m *MockAPI) CheckRequirements(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CheckRequirements", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// CheckRequirements indicates an expected call of CheckRequirements
func (mr *MockAPIMockRecorder) CheckRequirements(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckRequirements", reflect.TypeOf((*MockAPI)(nil).CheckRequirements), arg0, arg1)
}
// CheckVPC mocks base method
func (m *MockAPI) CheckVPC(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CheckVPC", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// CheckVPC indicates an expected call of CheckVPC
func (mr *MockAPIMockRecorder) CheckVPC(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckVPC", reflect.TypeOf((*MockAPI)(nil).CheckVPC), arg0, arg1)
}
// ClusterExists mocks base method
func (m *MockAPI) ClusterExists(arg0 context.Context, arg1 string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClusterExists", arg0, arg1)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ClusterExists indicates an expected call of ClusterExists
func (mr *MockAPIMockRecorder) ClusterExists(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterExists", reflect.TypeOf((*MockAPI)(nil).ClusterExists), arg0, arg1)
}
// CreateChangeSet mocks base method
func (m *MockAPI) CreateChangeSet(arg0 context.Context, arg1 string, arg2 []byte) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateChangeSet", arg0, arg1, arg2)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateChangeSet indicates an expected call of CreateChangeSet
func (mr *MockAPIMockRecorder) CreateChangeSet(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateChangeSet", reflect.TypeOf((*MockAPI)(nil).CreateChangeSet), arg0, arg1, arg2)
}
// CreateCluster mocks base method
func (m *MockAPI) CreateCluster(arg0 context.Context, arg1 string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateCluster", arg0, arg1)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateCluster indicates an expected call of CreateCluster
func (mr *MockAPIMockRecorder) CreateCluster(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCluster", reflect.TypeOf((*MockAPI)(nil).CreateCluster), arg0, arg1)
}
// CreateSecret mocks base method
func (m *MockAPI) CreateSecret(arg0 context.Context, arg1 secrets.Secret) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateSecret", arg0, arg1)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateSecret indicates an expected call of CreateSecret
func (mr *MockAPIMockRecorder) CreateSecret(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSecret", reflect.TypeOf((*MockAPI)(nil).CreateSecret), arg0, arg1)
}
// CreateStack mocks base method
func (m *MockAPI) CreateStack(arg0 context.Context, arg1 string, arg2 []byte) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateStack", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// CreateStack indicates an expected call of CreateStack
func (mr *MockAPIMockRecorder) CreateStack(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateStack", reflect.TypeOf((*MockAPI)(nil).CreateStack), arg0, arg1, arg2)
}
// DeleteAutoscalingGroup mocks base method
func (m *MockAPI) DeleteAutoscalingGroup(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAutoscalingGroup", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAutoscalingGroup indicates an expected call of DeleteAutoscalingGroup
func (mr *MockAPIMockRecorder) DeleteAutoscalingGroup(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAutoscalingGroup", reflect.TypeOf((*MockAPI)(nil).DeleteAutoscalingGroup), arg0, arg1)
}
// DeleteCapacityProvider mocks base method
func (m *MockAPI) DeleteCapacityProvider(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteCapacityProvider", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteCapacityProvider indicates an expected call of DeleteCapacityProvider
func (mr *MockAPIMockRecorder) DeleteCapacityProvider(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCapacityProvider", reflect.TypeOf((*MockAPI)(nil).DeleteCapacityProvider), arg0, arg1)
}
// DeleteSecret mocks base method
func (m *MockAPI) DeleteSecret(arg0 context.Context, arg1 string, arg2 bool) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteSecret", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteSecret indicates an expected call of DeleteSecret
func (mr *MockAPIMockRecorder) DeleteSecret(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSecret", reflect.TypeOf((*MockAPI)(nil).DeleteSecret), arg0, arg1, arg2)
}
// DeleteStack mocks base method
func (m *MockAPI) DeleteStack(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteStack", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteStack indicates an expected call of DeleteStack
func (mr *MockAPIMockRecorder) DeleteStack(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStack", reflect.TypeOf((*MockAPI)(nil).DeleteStack), arg0, arg1)
}
// DescribeService mocks base method
func (m *MockAPI) DescribeService(arg0 context.Context, arg1, arg2 string) (compose.ServiceStatus, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DescribeService", arg0, arg1, arg2)
ret0, _ := ret[0].(compose.ServiceStatus)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DescribeService indicates an expected call of DescribeService
func (mr *MockAPIMockRecorder) DescribeService(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeService", reflect.TypeOf((*MockAPI)(nil).DescribeService), arg0, arg1, arg2)
}
// DescribeStackEvents mocks base method
func (m *MockAPI) DescribeStackEvents(arg0 context.Context, arg1 string) ([]*cloudformation.StackEvent, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DescribeStackEvents", arg0, arg1)
ret0, _ := ret[0].([]*cloudformation.StackEvent)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DescribeStackEvents indicates an expected call of DescribeStackEvents
func (mr *MockAPIMockRecorder) DescribeStackEvents(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeStackEvents", reflect.TypeOf((*MockAPI)(nil).DescribeStackEvents), arg0, arg1)
}
// GetDefaultVPC mocks base method
func (m *MockAPI) GetDefaultVPC(arg0 context.Context) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetDefaultVPC", arg0)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetDefaultVPC indicates an expected call of GetDefaultVPC
func (mr *MockAPIMockRecorder) GetDefaultVPC(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultVPC", reflect.TypeOf((*MockAPI)(nil).GetDefaultVPC), arg0)
}
// GetLoadBalancerURL mocks base method
func (m *MockAPI) GetLoadBalancerURL(arg0 context.Context, arg1 string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetLoadBalancerURL", arg0, arg1)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetLoadBalancerURL indicates an expected call of GetLoadBalancerURL
func (mr *MockAPIMockRecorder) GetLoadBalancerURL(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLoadBalancerURL", reflect.TypeOf((*MockAPI)(nil).GetLoadBalancerURL), arg0, arg1)
}
// GetLogs mocks base method
func (m *MockAPI) GetLogs(arg0 context.Context, arg1 string, arg2 func(string, string, string)) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetLogs", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// GetLogs indicates an expected call of GetLogs
func (mr *MockAPIMockRecorder) GetLogs(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogs", reflect.TypeOf((*MockAPI)(nil).GetLogs), arg0, arg1, arg2)
}
// GetParameter mocks base method
func (m *MockAPI) GetParameter(arg0 context.Context, arg1 string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetParameter", arg0, arg1)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetParameter indicates an expected call of GetParameter
func (mr *MockAPIMockRecorder) GetParameter(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetParameter", reflect.TypeOf((*MockAPI)(nil).GetParameter), arg0, arg1)
}
// GetPublicIPs mocks base method
func (m *MockAPI) GetPublicIPs(arg0 context.Context, arg1 ...string) (map[string]string, error) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "GetPublicIPs", varargs...)
ret0, _ := ret[0].(map[string]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPublicIPs indicates an expected call of GetPublicIPs
func (mr *MockAPIMockRecorder) GetPublicIPs(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublicIPs", reflect.TypeOf((*MockAPI)(nil).GetPublicIPs), varargs...)
}
// GetRoleArn mocks base method
func (m *MockAPI) GetRoleArn(arg0 context.Context, arg1 string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetRoleArn", arg0, arg1)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetRoleArn indicates an expected call of GetRoleArn
func (mr *MockAPIMockRecorder) GetRoleArn(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoleArn", reflect.TypeOf((*MockAPI)(nil).GetRoleArn), arg0, arg1)
}
// GetServiceTaskDefinition mocks base method
func (m *MockAPI) GetServiceTaskDefinition(arg0 context.Context, arg1 string, arg2 []string) (map[string]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceTaskDefinition", arg0, arg1, arg2)
ret0, _ := ret[0].(map[string]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceTaskDefinition indicates an expected call of GetServiceTaskDefinition
func (mr *MockAPIMockRecorder) GetServiceTaskDefinition(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceTaskDefinition", reflect.TypeOf((*MockAPI)(nil).GetServiceTaskDefinition), arg0, arg1, arg2)
}
// GetServiceTasks mocks base method
func (m *MockAPI) GetServiceTasks(arg0 context.Context, arg1, arg2 string, arg3 bool) ([]*ecs.Task, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceTasks", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]*ecs.Task)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceTasks indicates an expected call of GetServiceTasks
func (mr *MockAPIMockRecorder) GetServiceTasks(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceTasks", reflect.TypeOf((*MockAPI)(nil).GetServiceTasks), arg0, arg1, arg2, arg3)
}
// GetStackClusterID mocks base method
func (m *MockAPI) GetStackClusterID(arg0 context.Context, arg1 string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetStackClusterID", arg0, arg1)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetStackClusterID indicates an expected call of GetStackClusterID
func (mr *MockAPIMockRecorder) GetStackClusterID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStackClusterID", reflect.TypeOf((*MockAPI)(nil).GetStackClusterID), arg0, arg1)
}
// GetStackID mocks base method
func (m *MockAPI) GetStackID(arg0 context.Context, arg1 string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetStackID", arg0, arg1)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetStackID indicates an expected call of GetStackID
func (mr *MockAPIMockRecorder) GetStackID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStackID", reflect.TypeOf((*MockAPI)(nil).GetStackID), arg0, arg1)
}
// GetSubNets mocks base method
func (m *MockAPI) GetSubNets(arg0 context.Context, arg1 string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetSubNets", arg0, arg1)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetSubNets indicates an expected call of GetSubNets
func (mr *MockAPIMockRecorder) GetSubNets(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubNets", reflect.TypeOf((*MockAPI)(nil).GetSubNets), arg0, arg1)
}
// GetTaskStoppedReason mocks base method
func (m *MockAPI) GetTaskStoppedReason(arg0 context.Context, arg1, arg2 string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTaskStoppedReason", arg0, arg1, arg2)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTaskStoppedReason indicates an expected call of GetTaskStoppedReason
func (mr *MockAPIMockRecorder) GetTaskStoppedReason(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskStoppedReason", reflect.TypeOf((*MockAPI)(nil).GetTaskStoppedReason), arg0, arg1, arg2)
}
// InspectSecret mocks base method
func (m *MockAPI) InspectSecret(arg0 context.Context, arg1 string) (secrets.Secret, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InspectSecret", arg0, arg1)
ret0, _ := ret[0].(secrets.Secret)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InspectSecret indicates an expected call of InspectSecret
func (mr *MockAPIMockRecorder) InspectSecret(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InspectSecret", reflect.TypeOf((*MockAPI)(nil).InspectSecret), arg0, arg1)
}
// ListSecrets mocks base method
func (m *MockAPI) ListSecrets(arg0 context.Context) ([]secrets.Secret, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListSecrets", arg0)
ret0, _ := ret[0].([]secrets.Secret)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListSecrets indicates an expected call of ListSecrets
func (mr *MockAPIMockRecorder) ListSecrets(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListSecrets", reflect.TypeOf((*MockAPI)(nil).ListSecrets), arg0)
}
// ListStackParameters mocks base method
func (m *MockAPI) ListStackParameters(arg0 context.Context, arg1 string) (map[string]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListStackParameters", arg0, arg1)
ret0, _ := ret[0].(map[string]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListStackParameters indicates an expected call of ListStackParameters
func (mr *MockAPIMockRecorder) ListStackParameters(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListStackParameters", reflect.TypeOf((*MockAPI)(nil).ListStackParameters), arg0, arg1)
}
// ListStackResources mocks base method
func (m *MockAPI) ListStackResources(arg0 context.Context, arg1 string) (stackResources, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListStackResources", arg0, arg1)
ret0, _ := ret[0].(stackResources)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListStackResources indicates an expected call of ListStackResources
func (mr *MockAPIMockRecorder) ListStackResources(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListStackResources", reflect.TypeOf((*MockAPI)(nil).ListStackResources), arg0, arg1)
}
// ListStackServices mocks base method
func (m *MockAPI) ListStackServices(arg0 context.Context, arg1 string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListStackServices", arg0, arg1)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListStackServices indicates an expected call of ListStackServices
func (mr *MockAPIMockRecorder) ListStackServices(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListStackServices", reflect.TypeOf((*MockAPI)(nil).ListStackServices), arg0, arg1)
}
// ListStacks mocks base method
func (m *MockAPI) ListStacks(arg0 context.Context, arg1 string) ([]compose.Stack, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListStacks", arg0, arg1)
ret0, _ := ret[0].([]compose.Stack)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListStacks indicates an expected call of ListStacks
func (mr *MockAPIMockRecorder) ListStacks(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListStacks", reflect.TypeOf((*MockAPI)(nil).ListStacks), arg0, arg1)
}
// ListTasks mocks base method
func (m *MockAPI) ListTasks(arg0 context.Context, arg1, arg2 string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListTasks", arg0, arg1, arg2)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListTasks indicates an expected call of ListTasks
func (mr *MockAPIMockRecorder) ListTasks(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTasks", reflect.TypeOf((*MockAPI)(nil).ListTasks), arg0, arg1, arg2)
}
// LoadBalancerType mocks base method
func (m *MockAPI) LoadBalancerType(arg0 context.Context, arg1 string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoadBalancerType", arg0, arg1)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LoadBalancerType indicates an expected call of LoadBalancerType
func (mr *MockAPIMockRecorder) LoadBalancerType(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadBalancerType", reflect.TypeOf((*MockAPI)(nil).LoadBalancerType), arg0, arg1)
}
// SecurityGroupExists mocks base method
func (m *MockAPI) SecurityGroupExists(arg0 context.Context, arg1 string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SecurityGroupExists", arg0, arg1)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SecurityGroupExists indicates an expected call of SecurityGroupExists
func (mr *MockAPIMockRecorder) SecurityGroupExists(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SecurityGroupExists", reflect.TypeOf((*MockAPI)(nil).SecurityGroupExists), arg0, arg1)
}
// StackExists mocks base method
func (m *MockAPI) StackExists(arg0 context.Context, arg1 string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StackExists", arg0, arg1)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// StackExists indicates an expected call of StackExists
func (mr *MockAPIMockRecorder) StackExists(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StackExists", reflect.TypeOf((*MockAPI)(nil).StackExists), arg0, arg1)
}
// UpdateStack mocks base method
func (m *MockAPI) UpdateStack(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateStack", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateStack indicates an expected call of UpdateStack
func (mr *MockAPIMockRecorder) UpdateStack(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateStack", reflect.TypeOf((*MockAPI)(nil).UpdateStack), arg0, arg1)
}
// WaitStackComplete mocks base method
func (m *MockAPI) WaitStackComplete(arg0 context.Context, arg1 string, arg2 int) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "WaitStackComplete", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// WaitStackComplete indicates an expected call of WaitStackComplete
func (mr *MockAPIMockRecorder) WaitStackComplete(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitStackComplete", reflect.TypeOf((*MockAPI)(nil).WaitStackComplete), arg0, arg1, arg2)
}
// WithVolumeSecurityGroups mocks base method
func (m *MockAPI) WithVolumeSecurityGroups(arg0 context.Context, arg1 string, arg2 func([]string) error) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "WithVolumeSecurityGroups", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// WithVolumeSecurityGroups indicates an expected call of WithVolumeSecurityGroups
func (mr *MockAPIMockRecorder) WithVolumeSecurityGroups(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithVolumeSecurityGroups", reflect.TypeOf((*MockAPI)(nil).WithVolumeSecurityGroups), arg0, arg1, arg2)
}
// getURLWithPortMapping mocks base method
func (m *MockAPI) getURLWithPortMapping(arg0 context.Context, arg1 []string) ([]compose.PortPublisher, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "getURLWithPortMapping", arg0, arg1)
ret0, _ := ret[0].([]compose.PortPublisher)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// getURLWithPortMapping indicates an expected call of getURLWithPortMapping
func (mr *MockAPIMockRecorder) getURLWithPortMapping(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getURLWithPortMapping", reflect.TypeOf((*MockAPI)(nil).getURLWithPortMapping), arg0, arg1)
}

View File

@ -77,14 +77,14 @@ func getEcsAPIService(ecsCtx store.EcsContext) (*ecsAPIService, error) {
return &ecsAPIService{
ctx: ecsCtx,
Region: ecsCtx.Region,
SDK: sdk,
aws: sdk,
}, nil
}
type ecsAPIService struct {
ctx store.EcsContext
Region string
SDK sdk
aws API
}
func (a *ecsAPIService) ContainerService() containers.Service {

View File

@ -38,6 +38,15 @@ import (
)
func (b *ecsAPIService) Convert(ctx context.Context, project *types.Project) ([]byte, error) {
template, err := b.convert(ctx, project)
if err != nil {
return nil, err
}
return marshall(template)
}
func (b *ecsAPIService) convert(ctx context.Context, project *types.Project) (*cloudformation.Template, error) {
err := b.checkCompatibility(project)
if err != nil {
return nil, err
@ -48,32 +57,6 @@ func (b *ecsAPIService) Convert(ctx context.Context, project *types.Project) ([]
return nil, err
}
template, err := b.convert(project, resources)
if err != nil {
return nil, err
}
// Create a NFS inbound rule on each mount target for volumes
// as "source security group" use an arbitrary network attached to service(s) who mounts target volume
for n, vol := range project.Volumes {
err := b.SDK.WithVolumeSecurityGroups(ctx, vol.Name, func(securityGroups []string) error {
return b.createNFSmountIngress(securityGroups, project, n, template)
})
if err != nil {
return nil, err
}
}
err = b.createCapacityProvider(ctx, project, template, resources)
if err != nil {
return nil, err
}
return marshall(template)
}
// Convert a compose project into a CloudFormation template
func (b *ecsAPIService) convert(project *types.Project, resources awsResources) (*cloudformation.Template, error) {
template := cloudformation.NewTemplate()
b.ensureResources(&resources, project, template)
@ -90,105 +73,130 @@ func (b *ecsAPIService) convert(project *types.Project, resources awsResources)
b.createCloudMap(project, template, resources.vpc)
for _, service := range project.Services {
taskExecutionRole := b.createTaskExecutionRole(project, service, template)
taskRole := b.createTaskRole(project, service, template)
definition, err := b.createTaskDefinition(project, service)
err := b.createService(project, service, template, resources)
if err != nil {
return nil, err
}
definition.ExecutionRoleArn = cloudformation.Ref(taskExecutionRole)
if taskRole != "" {
definition.TaskRoleArn = cloudformation.Ref(taskRole)
}
taskDefinition := fmt.Sprintf("%sTaskDefinition", normalizeResourceName(service.Name))
template.Resources[taskDefinition] = definition
var healthCheck *cloudmap.Service_HealthCheckConfig
serviceRegistry := b.createServiceRegistry(service, template, healthCheck)
var (
dependsOn []string
serviceLB []ecs.Service_LoadBalancer
)
for _, port := range service.Ports {
for net := range service.Networks {
b.createIngress(service, net, port, template, resources)
}
protocol := strings.ToUpper(port.Protocol)
if resources.loadBalancerType == elbv2.LoadBalancerTypeEnumApplication {
// we don't set Https as a certificate must be specified for HTTPS listeners
protocol = elbv2.ProtocolEnumHttp
}
targetGroupName := b.createTargetGroup(project, service, port, template, protocol, resources.vpc)
listenerName := b.createListener(service, port, template, targetGroupName, resources.loadBalancer, protocol)
dependsOn = append(dependsOn, listenerName)
serviceLB = append(serviceLB, ecs.Service_LoadBalancer{
ContainerName: service.Name,
ContainerPort: int(port.Target),
TargetGroupArn: cloudformation.Ref(targetGroupName),
})
}
desiredCount := 1
if service.Deploy != nil && service.Deploy.Replicas != nil {
desiredCount = int(*service.Deploy.Replicas)
}
for dependency := range service.DependsOn {
dependsOn = append(dependsOn, serviceResourceName(dependency))
}
minPercent, maxPercent, err := computeRollingUpdateLimits(service)
if err != nil {
return nil, err
}
assignPublicIP := ecsapi.AssignPublicIpEnabled
launchType := ecsapi.LaunchTypeFargate
platformVersion := "1.4.0" // LATEST which is set to 1.3.0 (?) which doesnt allow efs volumes.
if requireEC2(service) {
assignPublicIP = ecsapi.AssignPublicIpDisabled
launchType = ecsapi.LaunchTypeEc2
platformVersion = "" // The platform version must be null when specifying an EC2 launch type
}
template.Resources[serviceResourceName(service.Name)] = &ecs.Service{
AWSCloudFormationDependsOn: dependsOn,
Cluster: resources.cluster,
DesiredCount: desiredCount,
DeploymentController: &ecs.Service_DeploymentController{
Type: ecsapi.DeploymentControllerTypeEcs,
},
DeploymentConfiguration: &ecs.Service_DeploymentConfiguration{
MaximumPercent: maxPercent,
MinimumHealthyPercent: minPercent,
},
LaunchType: launchType,
// TODO we miss support for https://github.com/aws/containers-roadmap/issues/631 to select a capacity provider
LoadBalancers: serviceLB,
NetworkConfiguration: &ecs.Service_NetworkConfiguration{
AwsvpcConfiguration: &ecs.Service_AwsVpcConfiguration{
AssignPublicIp: assignPublicIP,
SecurityGroups: resources.serviceSecurityGroups(service),
Subnets: resources.subnets,
},
},
PlatformVersion: platformVersion,
PropagateTags: ecsapi.PropagateTagsService,
SchedulingStrategy: ecsapi.SchedulingStrategyReplica,
ServiceRegistries: []ecs.Service_ServiceRegistry{serviceRegistry},
Tags: serviceTags(project, service),
TaskDefinition: cloudformation.Ref(normalizeResourceName(taskDefinition)),
}
b.createAutoscalingPolicy(project, resources, template, service)
}
// Create a NFS inbound rule on each mount target for volumes
// as "source security group" use an arbitrary network attached to service(s) who mounts target volume
for n, vol := range project.Volumes {
err := b.aws.WithVolumeSecurityGroups(ctx, vol.Name, func(securityGroups []string) error {
return b.createNFSmountIngress(securityGroups, project, n, template)
})
if err != nil {
return nil, err
}
}
err = b.createCapacityProvider(ctx, project, template, resources)
if err != nil {
return nil, err
}
return template, nil
}
func (b *ecsAPIService) createService(project *types.Project, service types.ServiceConfig, template *cloudformation.Template, resources awsResources) error {
taskExecutionRole := b.createTaskExecutionRole(project, service, template)
taskRole := b.createTaskRole(project, service, template)
definition, err := b.createTaskDefinition(project, service)
if err != nil {
return err
}
definition.ExecutionRoleArn = cloudformation.Ref(taskExecutionRole)
if taskRole != "" {
definition.TaskRoleArn = cloudformation.Ref(taskRole)
}
taskDefinition := fmt.Sprintf("%sTaskDefinition", normalizeResourceName(service.Name))
template.Resources[taskDefinition] = definition
var healthCheck *cloudmap.Service_HealthCheckConfig
serviceRegistry := b.createServiceRegistry(service, template, healthCheck)
var (
dependsOn []string
serviceLB []ecs.Service_LoadBalancer
)
for _, port := range service.Ports {
for net := range service.Networks {
b.createIngress(service, net, port, template, resources)
}
protocol := strings.ToUpper(port.Protocol)
if resources.loadBalancerType == elbv2.LoadBalancerTypeEnumApplication {
// we don't set Https as a certificate must be specified for HTTPS listeners
protocol = elbv2.ProtocolEnumHttp
}
targetGroupName := b.createTargetGroup(project, service, port, template, protocol, resources.vpc)
listenerName := b.createListener(service, port, template, targetGroupName, resources.loadBalancer, protocol)
dependsOn = append(dependsOn, listenerName)
serviceLB = append(serviceLB, ecs.Service_LoadBalancer{
ContainerName: service.Name,
ContainerPort: int(port.Target),
TargetGroupArn: cloudformation.Ref(targetGroupName),
})
}
desiredCount := 1
if service.Deploy != nil && service.Deploy.Replicas != nil {
desiredCount = int(*service.Deploy.Replicas)
}
for dependency := range service.DependsOn {
dependsOn = append(dependsOn, serviceResourceName(dependency))
}
minPercent, maxPercent, err := computeRollingUpdateLimits(service)
if err != nil {
return err
}
assignPublicIP := ecsapi.AssignPublicIpEnabled
launchType := ecsapi.LaunchTypeFargate
platformVersion := "1.4.0" // LATEST which is set to 1.3.0 (?) which doesnt allow efs volumes.
if requireEC2(service) {
assignPublicIP = ecsapi.AssignPublicIpDisabled
launchType = ecsapi.LaunchTypeEc2
platformVersion = "" // The platform version must be null when specifying an EC2 launch type
}
template.Resources[serviceResourceName(service.Name)] = &ecs.Service{
AWSCloudFormationDependsOn: dependsOn,
Cluster: resources.cluster,
DesiredCount: desiredCount,
DeploymentController: &ecs.Service_DeploymentController{
Type: ecsapi.DeploymentControllerTypeEcs,
},
DeploymentConfiguration: &ecs.Service_DeploymentConfiguration{
MaximumPercent: maxPercent,
MinimumHealthyPercent: minPercent,
},
LaunchType: launchType,
// TODO we miss support for https://github.com/aws/containers-roadmap/issues/631 to select a capacity provider
LoadBalancers: serviceLB,
NetworkConfiguration: &ecs.Service_NetworkConfiguration{
AwsvpcConfiguration: &ecs.Service_AwsVpcConfiguration{
AssignPublicIp: assignPublicIP,
SecurityGroups: resources.serviceSecurityGroups(service),
Subnets: resources.subnets,
},
},
PlatformVersion: platformVersion,
PropagateTags: ecsapi.PropagateTagsService,
SchedulingStrategy: ecsapi.SchedulingStrategyReplica,
ServiceRegistries: []ecs.Service_ServiceRegistry{serviceRegistry},
Tags: serviceTags(project, service),
TaskDefinition: cloudformation.Ref(normalizeResourceName(taskDefinition)),
}
return nil
}
const allProtocols = "-1"
func (b *ecsAPIService) createIngress(service types.ServiceConfig, net string, port types.ServicePortConfig, template *cloudformation.Template, resources awsResources) {

View File

@ -17,10 +17,14 @@
package ecs
import (
"context"
"fmt"
"io/ioutil"
"reflect"
"testing"
"github.com/golang/mock/gomock"
"github.com/docker/compose-cli/api/compose"
"github.com/aws/aws-sdk-go/service/elbv2"
@ -30,7 +34,6 @@ import (
"github.com/awslabs/goformation/v4/cloudformation/elasticloadbalancingv2"
"github.com/awslabs/goformation/v4/cloudformation/iam"
"github.com/awslabs/goformation/v4/cloudformation/logs"
"github.com/compose-spec/compose-go/cli"
"github.com/compose-spec/compose-go/loader"
"github.com/compose-spec/compose-go/types"
"gotest.tools/v3/assert"
@ -38,8 +41,12 @@ import (
)
func TestSimpleConvert(t *testing.T) {
project := load(t, "testdata/input/simple-single-service.yaml")
result := convertResultAsString(t, project)
bytes, err := ioutil.ReadFile("testdata/input/simple-single-service.yaml")
assert.NilError(t, err)
template := convertYaml(t, string(bytes), useDefaultVPC)
resultAsJSON, err := marshall(template)
assert.NilError(t, err)
result := fmt.Sprintf("%s\n", string(resultAsJSON))
expected := "simple/simple-cloudformation-conversion.golden"
golden.Assert(t, result, expected)
}
@ -54,7 +61,7 @@ services:
awslogs-datetime-pattern: "FOO"
x-aws-logs_retention: 10
`)
`, useDefaultVPC)
def := template.Resources["FooTaskDefinition"].(*ecs.TaskDefinition)
logging := getMainContainer(def, t).LogConfiguration
if logging != nil {
@ -74,7 +81,7 @@ services:
image: hello_world
env_file:
- testdata/input/envfile
`)
`, useDefaultVPC)
def := template.Resources["FooTaskDefinition"].(*ecs.TaskDefinition)
env := getMainContainer(def, t).Environment
var found bool
@ -96,7 +103,7 @@ services:
- testdata/input/envfile
environment:
- "FOO=ZOT"
`)
`, useDefaultVPC)
def := template.Resources["FooTaskDefinition"].(*ecs.TaskDefinition)
env := getMainContainer(def, t).Environment
var found bool
@ -118,7 +125,7 @@ services:
replicas: 4
update_config:
parallelism: 2
`)
`, useDefaultVPC)
service := template.Resources["FooService"].(*ecs.Service)
assert.Check(t, service.DeploymentConfiguration.MaximumPercent == 150)
assert.Check(t, service.DeploymentConfiguration.MinimumHealthyPercent == 50)
@ -133,7 +140,7 @@ services:
update_config:
x-aws-min_percent: 25
x-aws-max_percent: 125
`)
`, useDefaultVPC)
service := template.Resources["FooService"].(*ecs.Service)
assert.Check(t, service.DeploymentConfiguration.MaximumPercent == 125)
assert.Check(t, service.DeploymentConfiguration.MinimumHealthyPercent == 25)
@ -145,7 +152,7 @@ services:
foo:
image: hello_world
x-aws-pull_credentials: "secret"
`)
`, useDefaultVPC)
x := template.Resources["FooTaskExecutionRole"]
assert.Check(t, x != nil)
role := *(x.(*iam.Role))
@ -173,7 +180,7 @@ networks:
name: public
back-tier:
internal: true
`)
`, useDefaultVPC)
assert.Check(t, template.Resources["FronttierNetwork"] != nil)
assert.Check(t, template.Resources["BacktierNetwork"] != nil)
assert.Check(t, template.Resources["BacktierNetworkIngress"] != nil)
@ -201,7 +208,7 @@ func TestLoadBalancerTypeApplication(t *testing.T) {
`,
}
for _, y := range cases {
template := convertYaml(t, y)
template := convertYaml(t, y, useDefaultVPC)
lb := template.Resources["LoadBalancer"]
assert.Check(t, lb != nil)
loadBalancer := *lb.(*elasticloadbalancingv2.LoadBalancer)
@ -218,7 +225,7 @@ services:
image: nginx
foo:
image: bar
`)
`, useDefaultVPC)
for _, r := range template.Resources {
assert.Check(t, r.AWSCloudFormationType() != "AWS::ElasticLoadBalancingV2::TargetGroup")
assert.Check(t, r.AWSCloudFormationType() != "AWS::ElasticLoadBalancingV2::Listener")
@ -233,7 +240,7 @@ services:
image: nginx
deploy:
replicas: 10
`)
`, useDefaultVPC)
s := template.Resources["TestService"]
assert.Check(t, s != nil)
service := *s.(*ecs.Service)
@ -245,7 +252,7 @@ func TestTaskSizeConvert(t *testing.T) {
services:
test:
image: nginx
`)
`, useDefaultVPC)
def := template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
assert.Equal(t, def.Cpu, "256")
assert.Equal(t, def.Memory, "512")
@ -259,7 +266,7 @@ services:
limits:
cpus: '0.5'
memory: 2048M
`)
`, useDefaultVPC)
def = template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
assert.Equal(t, def.Cpu, "512")
assert.Equal(t, def.Memory, "2048")
@ -273,7 +280,7 @@ services:
limits:
cpus: '4'
memory: 8192M
`)
`, useDefaultVPC)
def = template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
assert.Equal(t, def.Cpu, "4096")
assert.Equal(t, def.Memory, "8192")
@ -292,7 +299,7 @@ services:
- discrete_resource_spec:
kind: gpus
value: 2
`)
`, useDefaultVPC, useGPU)
def = template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
assert.Equal(t, def.Cpu, "4000")
assert.Equal(t, def.Memory, "792")
@ -308,26 +315,11 @@ services:
- discrete_resource_spec:
kind: gpus
value: 2
`)
`, useDefaultVPC, useGPU)
def = template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
assert.Equal(t, def.Cpu, "")
assert.Equal(t, def.Memory, "")
}
func TestTaskSizeConvertFailure(t *testing.T) {
model := loadConfig(t, `
services:
test:
image: nginx
deploy:
resources:
limits:
cpus: '0.5'
memory: 2043248M
`)
backend := &ecsAPIService{}
_, err := backend.convert(model, awsResources{})
assert.ErrorContains(t, err, "the resources requested are not supported by ECS/Fargate")
}
func TestLoadBalancerTypeNetwork(t *testing.T) {
template := convertYaml(t, `
@ -337,13 +329,32 @@ services:
ports:
- 80:80
- 88:88
`)
`, useDefaultVPC)
lb := template.Resources["LoadBalancer"]
assert.Check(t, lb != nil)
loadBalancer := *lb.(*elasticloadbalancingv2.LoadBalancer)
assert.Check(t, loadBalancer.Type == elbv2.LoadBalancerTypeEnumNetwork)
}
func TestUseCustomNetwork(t *testing.T) {
template := convertYaml(t, `
services:
test:
image: nginx
networks:
default:
external: true
name: sg-123abc
`, useDefaultVPC, func(m *MockAPIMockRecorder) {
m.SecurityGroupExists(gomock.Any(), "sg-123abc").Return(true, nil)
})
assert.Check(t, template.Resources["DefaultNetwork"] == nil)
assert.Check(t, template.Resources["DefaultNetworkIngress"] == nil)
s := template.Resources["TestService"].(*ecs.Service)
assert.Check(t, s != nil)
assert.Check(t, s.NetworkConfiguration.AwsvpcConfiguration.SecurityGroups[0] == "sg-123abc") //nolint:staticcheck
}
func TestServiceMapping(t *testing.T) {
template := convertYaml(t, `
services:
@ -360,7 +371,7 @@ services:
init: true
user: "user"
working_dir: "working_dir"
`)
`, useDefaultVPC)
def := template.Resources["TestTaskDefinition"].(*ecs.TaskDefinition)
container := getMainContainer(def, t)
assert.Equal(t, container.Image, "image")
@ -391,7 +402,7 @@ services:
ports:
- 80:80
- 88:88
`)
`, useDefaultVPC)
for _, r := range template.Resources {
tags := reflect.Indirect(reflect.ValueOf(r)).FieldByName("Tags")
if !tags.IsValid() {
@ -401,38 +412,26 @@ services:
k := tags.Index(i).FieldByName("Key").String()
v := tags.Index(i).FieldByName("Value").String()
if k == compose.ProjectTag {
assert.Equal(t, v, "Test")
assert.Equal(t, v, t.Name())
}
}
}
}
func convertResultAsString(t *testing.T, project *types.Project) string {
backend := &ecsAPIService{}
template, err := backend.convert(project, awsResources{
vpc: "vpcID",
subnets: []string{"subnet1", "subnet2"},
})
assert.NilError(t, err)
resultAsJSON, err := marshall(template)
assert.NilError(t, err)
return fmt.Sprintf("%s\n", string(resultAsJSON))
}
func load(t *testing.T, paths ...string) *types.Project {
options := cli.ProjectOptions{
Name: t.Name(),
ConfigPaths: paths,
}
project, err := cli.ProjectFromOptions(&options)
assert.NilError(t, err)
return project
}
func convertYaml(t *testing.T, yaml string) *cloudformation.Template {
func convertYaml(t *testing.T, yaml string, fn ...func(m *MockAPIMockRecorder)) *cloudformation.Template {
project := loadConfig(t, yaml)
backend := &ecsAPIService{}
template, err := backend.convert(project, awsResources{})
ctrl := gomock.NewController(t)
defer ctrl.Finish()
m := NewMockAPI(ctrl)
for _, f := range fn {
f(m.EXPECT())
}
backend := &ecsAPIService{
aws: m,
}
template, err := backend.convert(context.TODO(), project)
assert.NilError(t, err)
return template
}
@ -445,7 +444,7 @@ func loadConfig(t *testing.T, yaml string) *types.Project {
{Config: dict},
},
}, func(options *loader.Options) {
options.Name = "Test"
options.Name = t.Name()
})
assert.NilError(t, err)
return model
@ -460,3 +459,12 @@ func getMainContainer(def *ecs.TaskDefinition, t *testing.T) ecs.TaskDefinition_
t.Fail()
return def.ContainerDefinitions[0]
}
func useDefaultVPC(m *MockAPIMockRecorder) {
m.GetDefaultVPC(gomock.Any()).Return("vpc-123", nil)
m.GetSubNets(gomock.Any(), "vpc-123").Return([]string{"subnet1", "subnet2"}, nil)
}
func useGPU(m *MockAPIMockRecorder) {
m.GetParameter(gomock.Any(), gomock.Any()).Return("", nil)
}

View File

@ -97,6 +97,7 @@ var compatibleComposeAttributes = []string{
"secrets.file",
"volumes",
"volumes.external",
"networks.external",
}
func (c *fargateCompatibilityChecker) CheckImage(service *types.ServiceConfig) {

View File

@ -23,17 +23,17 @@ import (
)
func (b *ecsAPIService) Down(ctx context.Context, project string) error {
resources, err := b.SDK.ListStackResources(ctx, project)
resources, err := b.aws.ListStackResources(ctx, project)
if err != nil {
return err
}
err = resources.apply(awsTypeCapacityProvider, delete(ctx, b.SDK.DeleteCapacityProvider))
err = resources.apply(awsTypeCapacityProvider, delete(ctx, b.aws.DeleteCapacityProvider))
if err != nil {
return err
}
err = resources.apply(awsTypeAutoscalingGroup, delete(ctx, b.SDK.DeleteAutoscalingGroup))
err = resources.apply(awsTypeAutoscalingGroup, delete(ctx, b.aws.DeleteAutoscalingGroup))
if err != nil {
return err
}
@ -43,7 +43,7 @@ func (b *ecsAPIService) Down(ctx context.Context, project string) error {
return err
}
err = b.SDK.DeleteStack(ctx, project)
err = b.aws.DeleteStack(ctx, project)
if err != nil {
return err
}
@ -51,7 +51,7 @@ func (b *ecsAPIService) Down(ctx context.Context, project string) error {
}
func (b *ecsAPIService) previousStackEvents(ctx context.Context, project string) ([]string, error) {
events, err := b.SDK.DescribeStackEvents(ctx, project)
events, err := b.aws.DescribeStackEvents(ctx, project)
if err != nil {
return nil, err
}

View File

@ -41,7 +41,7 @@ func (b *ecsAPIService) createCapacityProvider(ctx context.Context, project *typ
return nil
}
ami, err := b.SDK.GetParameter(ctx, "/aws/service/ecs/optimized-ami/amazon-linux-2/gpu/recommended")
ami, err := b.aws.GetParameter(ctx, "/aws/service/ecs/optimized-ami/amazon-linux-2/gpu/recommended")
if err != nil {
return err
}

View File

@ -24,7 +24,7 @@ import (
)
func (b *ecsAPIService) List(ctx context.Context, project string) ([]compose.Stack, error) {
stacks, err := b.SDK.ListStacks(ctx, project)
stacks, err := b.aws.ListStacks(ctx, project)
if err != nil {
return nil, err
}
@ -42,7 +42,7 @@ func (b *ecsAPIService) List(ctx context.Context, project string) ([]compose.Sta
}
func (b *ecsAPIService) checkStackState(ctx context.Context, name string) error {
resources, err := b.SDK.ListStackResources(ctx, name)
resources, err := b.aws.ListStackResources(ctx, name)
if err != nil {
return err
}
@ -65,7 +65,7 @@ func (b *ecsAPIService) checkStackState(ctx context.Context, name string) error
if len(svcArns) == 0 {
return nil
}
services, err := b.SDK.GetServiceTaskDefinition(ctx, cluster, svcArns)
services, err := b.aws.GetServiceTaskDefinition(ctx, cluster, svcArns)
if err != nil {
return err
}
@ -78,14 +78,14 @@ func (b *ecsAPIService) checkStackState(ctx context.Context, name string) error
}
func (b *ecsAPIService) checkServiceState(ctx context.Context, cluster string, service string, taskdef string) error {
runningTasks, err := b.SDK.GetServiceTasks(ctx, cluster, service, false)
runningTasks, err := b.aws.GetServiceTasks(ctx, cluster, service, false)
if err != nil {
return err
}
if len(runningTasks) > 0 {
return nil
}
stoppedTasks, err := b.SDK.GetServiceTasks(ctx, cluster, service, true)
stoppedTasks, err := b.aws.GetServiceTasks(ctx, cluster, service, true)
if err != nil {
return err
}
@ -102,7 +102,7 @@ func (b *ecsAPIService) checkServiceState(ctx context.Context, cluster string, s
if len(tasks) == 0 {
return nil
}
reason, err := b.SDK.GetTaskStoppedReason(ctx, cluster, tasks[0])
reason, err := b.aws.GetTaskStoppedReason(ctx, cluster, tasks[0])
if err != nil {
return err
}

View File

@ -31,7 +31,7 @@ func (b *ecsAPIService) Logs(ctx context.Context, project string, w io.Writer) e
width: 0,
writer: w,
}
err := b.SDK.GetLogs(ctx, project, consumer.Log)
err := b.aws.GetLogs(ctx, project, consumer.Log)
return err
}

View File

@ -25,11 +25,11 @@ import (
)
func (b *ecsAPIService) Ps(ctx context.Context, project string) ([]compose.ServiceStatus, error) {
cluster, err := b.SDK.GetStackClusterID(ctx, project)
cluster, err := b.aws.GetStackClusterID(ctx, project)
if err != nil {
return nil, err
}
servicesARN, err := b.SDK.ListStackServices(ctx, project)
servicesARN, err := b.aws.ListStackServices(ctx, project)
if err != nil {
return nil, err
}
@ -40,7 +40,7 @@ func (b *ecsAPIService) Ps(ctx context.Context, project string) ([]compose.Servi
status := []compose.ServiceStatus{}
for _, arn := range servicesARN {
state, err := b.SDK.DescribeService(ctx, cluster, arn)
state, err := b.aws.DescribeService(ctx, cluster, arn)
if err != nil {
return nil, err
}

View File

@ -23,9 +23,6 @@ import (
"strings"
"time"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
"github.com/docker/compose-cli/api/compose"
"github.com/docker/compose-cli/api/secrets"
"github.com/docker/compose-cli/internal"
@ -51,6 +48,8 @@ import (
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/secretsmanager"
"github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
"github.com/hashicorp/go-multierror"
"github.com/sirupsen/logrus"
)
@ -68,6 +67,9 @@ type sdk struct {
AG autoscalingiface.AutoScalingAPI
}
// sdk implement API
var _ API = sdk{}
func newSDK(sess *session.Session) sdk {
sess.Handlers.Build.PushBack(func(r *request.Request) {
request.AddToUserAgent(r, internal.ECSUserAgentName+"/"+internal.Version)

View File

@ -23,17 +23,17 @@ import (
)
func (b *ecsAPIService) CreateSecret(ctx context.Context, secret secrets.Secret) (string, error) {
return b.SDK.CreateSecret(ctx, secret)
return b.aws.CreateSecret(ctx, secret)
}
func (b *ecsAPIService) InspectSecret(ctx context.Context, id string) (secrets.Secret, error) {
return b.SDK.InspectSecret(ctx, id)
return b.aws.InspectSecret(ctx, id)
}
func (b *ecsAPIService) ListSecrets(ctx context.Context) ([]secrets.Secret, error) {
return b.SDK.ListSecrets(ctx)
return b.aws.ListSecrets(ctx)
}
func (b *ecsAPIService) DeleteSecret(ctx context.Context, id string, recover bool) error {
return b.SDK.DeleteSecret(ctx, id, recover)
return b.aws.DeleteSecret(ctx, id, recover)
}

View File

@ -5,7 +5,7 @@
"Properties": {
"Description": "Service Map for Docker Compose project TestSimpleConvert",
"Name": "TestSimpleConvert.local",
"Vpc": "vpcID"
"Vpc": "vpc-123"
},
"Type": "AWS::ServiceDiscovery::PrivateDnsNamespace"
},
@ -47,7 +47,7 @@
"Value": "default"
}
],
"VpcId": "vpcID"
"VpcId": "vpc-123"
},
"Type": "AWS::EC2::SecurityGroup"
},
@ -218,7 +218,7 @@
}
],
"TargetType": "ip",
"VpcId": "vpcID"
"VpcId": "vpc-123"
},
"Type": "AWS::ElasticLoadBalancingV2::TargetGroup"
},

View File

@ -27,7 +27,7 @@ import (
)
func (b *ecsAPIService) Up(ctx context.Context, project *types.Project, detach bool) error {
err := b.SDK.CheckRequirements(ctx, b.Region)
err := b.aws.CheckRequirements(ctx, b.Region)
if err != nil {
return err
}
@ -37,23 +37,23 @@ func (b *ecsAPIService) Up(ctx context.Context, project *types.Project, detach b
return err
}
update, err := b.SDK.StackExists(ctx, project.Name)
update, err := b.aws.StackExists(ctx, project.Name)
if err != nil {
return err
}
operation := stackCreate
if update {
operation = stackUpdate
changeset, err := b.SDK.CreateChangeSet(ctx, project.Name, template)
changeset, err := b.aws.CreateChangeSet(ctx, project.Name, template)
if err != nil {
return err
}
err = b.SDK.UpdateStack(ctx, changeset)
err = b.aws.UpdateStack(ctx, changeset)
if err != nil {
return err
}
} else {
err = b.SDK.CreateStack(ctx, project.Name, template)
err = b.aws.CreateStack(ctx, project.Name, template)
if err != nil {
return err
}

View File

@ -37,7 +37,7 @@ func (b *ecsAPIService) WaitStackCompletion(ctx context.Context, name string, op
// progress writer
w := progress.ContextWriter(ctx)
// Get the unique Stack ID so we can collect events without getting some from previous deployments with same name
stackID, err := b.SDK.GetStackID(ctx, name)
stackID, err := b.aws.GetStackID(ctx, name)
if err != nil {
return err
}
@ -45,7 +45,7 @@ func (b *ecsAPIService) WaitStackCompletion(ctx context.Context, name string, op
ticker := time.NewTicker(1 * time.Second)
done := make(chan bool)
go func() {
b.SDK.WaitStackComplete(ctx, stackID, operation) //nolint:errcheck
b.aws.WaitStackComplete(ctx, stackID, operation) //nolint:errcheck
ticker.Stop()
done <- true
}()
@ -58,7 +58,7 @@ func (b *ecsAPIService) WaitStackCompletion(ctx context.Context, name string, op
completed = true
case <-ticker.C:
}
events, err := b.SDK.DescribeStackEvents(ctx, stackID)
events, err := b.aws.DescribeStackEvents(ctx, stackID)
if err != nil {
return err
}
@ -111,7 +111,7 @@ func (b *ecsAPIService) WaitStackCompletion(ctx context.Context, name string, op
continue
}
if err := b.checkStackState(ctx, name); err != nil {
if e := b.SDK.DeleteStack(ctx, name); e != nil {
if e := b.aws.DeleteStack(ctx, name); e != nil {
return e
}
stackErr = err

1
go.mod
View File

@ -36,6 +36,7 @@ require (
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect
github.com/gobwas/pool v0.2.0 // indirect
github.com/gobwas/ws v1.0.4
github.com/golang/mock v1.4.4
github.com/golang/protobuf v1.4.2
github.com/google/go-cmp v0.5.2
github.com/google/uuid v1.1.2

1
go.sum
View File

@ -220,6 +220,7 @@ github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFU
github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
github.com/golang/mock v1.4.1/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc=
github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=