Merge pull request #1064 from docker/public_subnets

Only consider public subnets
This commit is contained in:
Nicolas De loof 2020-12-15 14:38:55 +01:00 committed by GitHub
commit 42acaea3c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 66 additions and 4 deletions

View File

@ -40,6 +40,7 @@ type API interface {
CheckVPC(ctx context.Context, vpcID string) error
GetDefaultVPC(ctx context.Context) (string, error)
GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error)
IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error)
GetRoleArn(ctx context.Context, name string) (string, error)
StackExists(ctx context.Context, name string) (bool, error)
CreateStack(ctx context.Context, name string, region string, template []byte) error

View File

@ -185,10 +185,22 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr
if err != nil {
return "", nil, err
}
if len(subNets) < 2 {
return "", nil, fmt.Errorf("VPC %s should have at least 2 associated subnets in different availability zones", vpc)
var publicSubNets []awsResource
for _, subNet := range subNets {
isPublic, err := b.aws.IsPublicSubnet(ctx, vpc, subNet.ID())
if err != nil {
return "", nil, err
}
if isPublic {
publicSubNets = append(publicSubNets, subNet)
}
}
return vpc, subNets, nil
if len(publicSubNets) < 2 {
return "", nil, fmt.Errorf("VPC %s should have at least 2 associated public subnets in different availability zones", vpc)
}
return vpc, publicSubNets, nil
}
func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project) (awsResource, string, error) {

View File

@ -6,12 +6,13 @@ package ecs
import (
context "context"
reflect "reflect"
cloudformation "github.com/aws/aws-sdk-go/service/cloudformation"
ecs "github.com/aws/aws-sdk-go/service/ecs"
compose "github.com/docker/compose-cli/api/compose"
secrets "github.com/docker/compose-cli/api/secrets"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockAPI is a mock of API interface
@ -453,6 +454,21 @@ func (mr *MockAPIMockRecorder) InspectSecret(arg0, arg1 interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InspectSecret", reflect.TypeOf((*MockAPI)(nil).InspectSecret), arg0, arg1)
}
// IsPublicSubnet mocks base method
func (m *MockAPI) IsPublicSubnet(ctx context.Context, arg0 string, arg1 string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsPublicSubnet", arg0, arg1)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsPublicSubnet indicates an expected call of IsPublicSubnet
func (mr *MockAPIMockRecorder) IsPublicSubnet(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPublicSubnet", reflect.TypeOf((*MockAPI)(nil).IsPublicSubnet), arg0, arg1)
}
// ListFileSystems mocks base method
func (m *MockAPI) ListFileSystems(arg0 context.Context, arg1 map[string]string) ([]awsResource, error) {
m.ctrl.T.Helper()

View File

@ -591,6 +591,8 @@ func useDefaultVPC(m *MockAPIMockRecorder) {
existingAWSResource{id: "subnet1"},
existingAWSResource{id: "subnet2"},
}, nil)
m.IsPublicSubnet(gomock.Any(), "subnet1").Return(true, nil)
m.IsPublicSubnet(gomock.Any(), "subnet2").Return(true, nil)
}
func useGPU(m *MockAPIMockRecorder) {

View File

@ -211,6 +211,37 @@ func (s sdk) GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error
return ids, nil
}
func (s sdk) IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error) {
tables, err := s.EC2.DescribeRouteTablesWithContext(ctx, &ec2.DescribeRouteTablesInput{
Filters: []*ec2.Filter{
{
Name: aws.String("association.subnet-id"),
Values: []*string{aws.String(subNetID)},
},
},
})
if err != nil {
return false, err
}
if len(tables.RouteTables) == 0 {
// If a subnet is not explicitly associated with any route table, it is implicitly associated with the main route table.
// https://docs.aws.amazon.com/cli/latest/reference/ec2/describe-route-tables.html
return true, nil
}
for _, routeTable := range tables.RouteTables {
for _, route := range routeTable.Routes {
if aws.StringValue(route.State) != "active" {
continue
}
if strings.HasPrefix(aws.StringValue(route.GatewayId), "igw-") {
// Connected to an internet Gateway
return true, nil
}
}
}
return false, nil
}
func (s sdk) GetRoleArn(ctx context.Context, name string) (string, error) {
role, err := s.IAM.GetRoleWithContext(ctx, &iam.GetRoleInput{
RoleName: aws.String(name),