From b9a30258655f23135ab6df0343daae15b0810f07 Mon Sep 17 00:00:00 2001 From: Nicolas De Loof Date: Tue, 15 Dec 2020 12:33:53 +0100 Subject: [PATCH] Only consider public subnets Signed-off-by: Nicolas De Loof --- ecs/aws.go | 1 + ecs/awsResources.go | 18 +++++++++++++++--- ecs/aws_mock.go | 18 +++++++++++++++++- ecs/cloudformation_test.go | 2 ++ ecs/sdk.go | 31 +++++++++++++++++++++++++++++++ 5 files changed, 66 insertions(+), 4 deletions(-) diff --git a/ecs/aws.go b/ecs/aws.go index 5cf50fbfa..e20fe66ad 100644 --- a/ecs/aws.go +++ b/ecs/aws.go @@ -40,6 +40,7 @@ type API interface { CheckVPC(ctx context.Context, vpcID string) error GetDefaultVPC(ctx context.Context) (string, error) GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error) + IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error) GetRoleArn(ctx context.Context, name string) (string, error) StackExists(ctx context.Context, name string) (bool, error) CreateStack(ctx context.Context, name string, region string, template []byte) error diff --git a/ecs/awsResources.go b/ecs/awsResources.go index 012b31822..d1ff7cf4f 100644 --- a/ecs/awsResources.go +++ b/ecs/awsResources.go @@ -185,10 +185,22 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr if err != nil { return "", nil, err } - if len(subNets) < 2 { - return "", nil, fmt.Errorf("VPC %s should have at least 2 associated subnets in different availability zones", vpc) + + var publicSubNets []awsResource + for _, subNet := range subNets { + isPublic, err := b.aws.IsPublicSubnet(ctx, vpc, subNet.ID()) + if err != nil { + return "", nil, err + } + if isPublic { + publicSubNets = append(publicSubNets, subNet) + } } - return vpc, subNets, nil + + if len(publicSubNets) < 2 { + return "", nil, fmt.Errorf("VPC %s should have at least 2 associated public subnets in different availability zones", vpc) + } + return vpc, publicSubNets, nil } func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project) (awsResource, string, error) { diff --git a/ecs/aws_mock.go b/ecs/aws_mock.go index 9d438bf29..486648f44 100644 --- a/ecs/aws_mock.go +++ b/ecs/aws_mock.go @@ -6,12 +6,13 @@ package ecs import ( context "context" + reflect "reflect" + cloudformation "github.com/aws/aws-sdk-go/service/cloudformation" ecs "github.com/aws/aws-sdk-go/service/ecs" compose "github.com/docker/compose-cli/api/compose" secrets "github.com/docker/compose-cli/api/secrets" gomock "github.com/golang/mock/gomock" - reflect "reflect" ) // MockAPI is a mock of API interface @@ -453,6 +454,21 @@ func (mr *MockAPIMockRecorder) InspectSecret(arg0, arg1 interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InspectSecret", reflect.TypeOf((*MockAPI)(nil).InspectSecret), arg0, arg1) } +// IsPublicSubnet mocks base method +func (m *MockAPI) IsPublicSubnet(ctx context.Context, arg0 string, arg1 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsPublicSubnet", arg0, arg1) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsPublicSubnet indicates an expected call of IsPublicSubnet +func (mr *MockAPIMockRecorder) IsPublicSubnet(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPublicSubnet", reflect.TypeOf((*MockAPI)(nil).IsPublicSubnet), arg0, arg1) +} + // ListFileSystems mocks base method func (m *MockAPI) ListFileSystems(arg0 context.Context, arg1 map[string]string) ([]awsResource, error) { m.ctrl.T.Helper() diff --git a/ecs/cloudformation_test.go b/ecs/cloudformation_test.go index 275721c37..0021114a2 100644 --- a/ecs/cloudformation_test.go +++ b/ecs/cloudformation_test.go @@ -591,6 +591,8 @@ func useDefaultVPC(m *MockAPIMockRecorder) { existingAWSResource{id: "subnet1"}, existingAWSResource{id: "subnet2"}, }, nil) + m.IsPublicSubnet(gomock.Any(), "subnet1").Return(true, nil) + m.IsPublicSubnet(gomock.Any(), "subnet2").Return(true, nil) } func useGPU(m *MockAPIMockRecorder) { diff --git a/ecs/sdk.go b/ecs/sdk.go index b44ecbfab..a8e89946c 100644 --- a/ecs/sdk.go +++ b/ecs/sdk.go @@ -211,6 +211,37 @@ func (s sdk) GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error return ids, nil } +func (s sdk) IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error) { + tables, err := s.EC2.DescribeRouteTablesWithContext(ctx, &ec2.DescribeRouteTablesInput{ + Filters: []*ec2.Filter{ + { + Name: aws.String("association.subnet-id"), + Values: []*string{aws.String(subNetID)}, + }, + }, + }) + if err != nil { + return false, err + } + if len(tables.RouteTables) == 0 { + // If a subnet is not explicitly associated with any route table, it is implicitly associated with the main route table. + // https://docs.aws.amazon.com/cli/latest/reference/ec2/describe-route-tables.html + return true, nil + } + for _, routeTable := range tables.RouteTables { + for _, route := range routeTable.Routes { + if aws.StringValue(route.State) != "active" { + continue + } + if strings.HasPrefix(aws.StringValue(route.GatewayId), "igw-") { + // Connected to an internet Gateway + return true, nil + } + } + } + return false, nil +} + func (s sdk) GetRoleArn(ctx context.Context, name string) (string, error) { role, err := s.IAM.GetRoleWithContext(ctx, &iam.GetRoleInput{ RoleName: aws.String(name),