Merge pull request #1064 from docker/public_subnets

Only consider public subnets
This commit is contained in:
Nicolas De loof 2020-12-15 14:38:55 +01:00 committed by GitHub
commit 42acaea3c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 66 additions and 4 deletions

View File

@ -40,6 +40,7 @@ type API interface {
CheckVPC(ctx context.Context, vpcID string) error CheckVPC(ctx context.Context, vpcID string) error
GetDefaultVPC(ctx context.Context) (string, error) GetDefaultVPC(ctx context.Context) (string, error)
GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error) GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error)
IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error)
GetRoleArn(ctx context.Context, name string) (string, error) GetRoleArn(ctx context.Context, name string) (string, error)
StackExists(ctx context.Context, name string) (bool, error) StackExists(ctx context.Context, name string) (bool, error)
CreateStack(ctx context.Context, name string, region string, template []byte) error CreateStack(ctx context.Context, name string, region string, template []byte) error

View File

@ -185,10 +185,22 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
if len(subNets) < 2 {
return "", nil, fmt.Errorf("VPC %s should have at least 2 associated subnets in different availability zones", vpc) var publicSubNets []awsResource
for _, subNet := range subNets {
isPublic, err := b.aws.IsPublicSubnet(ctx, vpc, subNet.ID())
if err != nil {
return "", nil, err
}
if isPublic {
publicSubNets = append(publicSubNets, subNet)
}
} }
return vpc, subNets, nil
if len(publicSubNets) < 2 {
return "", nil, fmt.Errorf("VPC %s should have at least 2 associated public subnets in different availability zones", vpc)
}
return vpc, publicSubNets, nil
} }
func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project) (awsResource, string, error) { func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project) (awsResource, string, error) {

View File

@ -6,12 +6,13 @@ package ecs
import ( import (
context "context" context "context"
reflect "reflect"
cloudformation "github.com/aws/aws-sdk-go/service/cloudformation" cloudformation "github.com/aws/aws-sdk-go/service/cloudformation"
ecs "github.com/aws/aws-sdk-go/service/ecs" ecs "github.com/aws/aws-sdk-go/service/ecs"
compose "github.com/docker/compose-cli/api/compose" compose "github.com/docker/compose-cli/api/compose"
secrets "github.com/docker/compose-cli/api/secrets" secrets "github.com/docker/compose-cli/api/secrets"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
reflect "reflect"
) )
// MockAPI is a mock of API interface // MockAPI is a mock of API interface
@ -453,6 +454,21 @@ func (mr *MockAPIMockRecorder) InspectSecret(arg0, arg1 interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InspectSecret", reflect.TypeOf((*MockAPI)(nil).InspectSecret), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InspectSecret", reflect.TypeOf((*MockAPI)(nil).InspectSecret), arg0, arg1)
} }
// IsPublicSubnet mocks base method
func (m *MockAPI) IsPublicSubnet(ctx context.Context, arg0 string, arg1 string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsPublicSubnet", arg0, arg1)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsPublicSubnet indicates an expected call of IsPublicSubnet
func (mr *MockAPIMockRecorder) IsPublicSubnet(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPublicSubnet", reflect.TypeOf((*MockAPI)(nil).IsPublicSubnet), arg0, arg1)
}
// ListFileSystems mocks base method // ListFileSystems mocks base method
func (m *MockAPI) ListFileSystems(arg0 context.Context, arg1 map[string]string) ([]awsResource, error) { func (m *MockAPI) ListFileSystems(arg0 context.Context, arg1 map[string]string) ([]awsResource, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -591,6 +591,8 @@ func useDefaultVPC(m *MockAPIMockRecorder) {
existingAWSResource{id: "subnet1"}, existingAWSResource{id: "subnet1"},
existingAWSResource{id: "subnet2"}, existingAWSResource{id: "subnet2"},
}, nil) }, nil)
m.IsPublicSubnet(gomock.Any(), "subnet1").Return(true, nil)
m.IsPublicSubnet(gomock.Any(), "subnet2").Return(true, nil)
} }
func useGPU(m *MockAPIMockRecorder) { func useGPU(m *MockAPIMockRecorder) {

View File

@ -211,6 +211,37 @@ func (s sdk) GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error
return ids, nil return ids, nil
} }
func (s sdk) IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error) {
tables, err := s.EC2.DescribeRouteTablesWithContext(ctx, &ec2.DescribeRouteTablesInput{
Filters: []*ec2.Filter{
{
Name: aws.String("association.subnet-id"),
Values: []*string{aws.String(subNetID)},
},
},
})
if err != nil {
return false, err
}
if len(tables.RouteTables) == 0 {
// If a subnet is not explicitly associated with any route table, it is implicitly associated with the main route table.
// https://docs.aws.amazon.com/cli/latest/reference/ec2/describe-route-tables.html
return true, nil
}
for _, routeTable := range tables.RouteTables {
for _, route := range routeTable.Routes {
if aws.StringValue(route.State) != "active" {
continue
}
if strings.HasPrefix(aws.StringValue(route.GatewayId), "igw-") {
// Connected to an internet Gateway
return true, nil
}
}
}
return false, nil
}
func (s sdk) GetRoleArn(ctx context.Context, name string) (string, error) { func (s sdk) GetRoleArn(ctx context.Context, name string) (string, error) {
role, err := s.IAM.GetRoleWithContext(ctx, &iam.GetRoleInput{ role, err := s.IAM.GetRoleWithContext(ctx, &iam.GetRoleInput{
RoleName: aws.String(name), RoleName: aws.String(name),