Use LoadBalancer's VPC and subnet when x-aws-loadbalancer is set

Signed-off-by: Nicolas De Loof <nicolas.deloof@gmail.com>
This commit is contained in:
Nicolas De Loof 2021-01-12 14:57:28 +01:00
parent f6e5c911ce
commit 075f54713e
No known key found for this signature in database
GPG Key ID: 9858809D6F8F6E7E
4 changed files with 58 additions and 33 deletions

View File

@ -40,7 +40,7 @@ type API interface {
CheckVPC(ctx context.Context, vpcID string) error
GetDefaultVPC(ctx context.Context) (string, error)
GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error)
IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error)
IsPublicSubnet(ctx context.Context, subNetID string) (bool, error)
GetRoleArn(ctx context.Context, name string) (string, error)
StackExists(ctx context.Context, name string) (bool, error)
CreateStack(ctx context.Context, name string, region string, template []byte) error
@ -68,7 +68,7 @@ type API interface {
getURLWithPortMapping(ctx context.Context, targetGroupArns []string) ([]compose.PortPublisher, error)
ListTasks(ctx context.Context, cluster string, family string) ([]string, error)
GetPublicIPs(ctx context.Context, interfaces ...string) (map[string]string, error)
ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, error)
ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, string, []awsResource, error)
GetLoadBalancerURL(ctx context.Context, arn string) (string, error)
GetParameter(ctx context.Context, name string) (string, error)
SecurityGroupExists(ctx context.Context, sg string) (bool, error)

View File

@ -129,11 +129,11 @@ func (b *ecsAPIService) parse(ctx context.Context, project *types.Project, templ
if err != nil {
return r, err
}
r.vpc, r.subnets, err = b.parseVPCExtension(ctx, project)
err = b.parseLoadBalancerExtension(ctx, project, &r)
if err != nil {
return r, err
}
r.loadBalancer, r.loadBalancerType, err = b.parseLoadBalancerExtension(ctx, project)
err = b.parseVPCExtension(ctx, project, &r)
if err != nil {
return r, err
}
@ -165,7 +165,7 @@ func (b *ecsAPIService) parseClusterExtension(ctx context.Context, project *type
return nil, nil
}
func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Project) (string, []awsResource, error) {
func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Project, r *awsResources) error {
var vpc string
if x, ok := project.Extensions[extensionVPC]; ok {
vpc = x.(string)
@ -177,29 +177,40 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr
vpc = id[i+1:]
}
if r.vpc != "" {
if r.vpc != vpc {
return fmt.Errorf("load balancer set by %s is attached to VPC %s", extensionLoadBalancer, r.vpc)
}
return nil
}
err = b.aws.CheckVPC(ctx, vpc)
if err != nil {
return "", nil, err
return err
}
} else {
if r.vpc != "" {
return nil
}
defaultVPC, err := b.aws.GetDefaultVPC(ctx)
if err != nil {
return "", nil, err
return err
}
vpc = defaultVPC
}
subNets, err := b.aws.GetSubNets(ctx, vpc)
if err != nil {
return "", nil, err
return err
}
var publicSubNets []awsResource
for _, subNet := range subNets {
isPublic, err := b.aws.IsPublicSubnet(ctx, vpc, subNet.ID())
isPublic, err := b.aws.IsPublicSubnet(ctx, subNet.ID())
if err != nil {
return "", nil, err
return err
}
if isPublic {
publicSubNets = append(publicSubNets, subNet)
@ -207,27 +218,34 @@ func (b *ecsAPIService) parseVPCExtension(ctx context.Context, project *types.Pr
}
if len(publicSubNets) < 2 {
return "", nil, fmt.Errorf("VPC %s should have at least 2 associated public subnets in different availability zones", vpc)
return fmt.Errorf("VPC %s should have at least 2 associated public subnets in different availability zones", vpc)
}
return vpc, publicSubNets, nil
r.vpc = vpc
r.subnets = subNets
return nil
}
func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project) (awsResource, string, error) {
func (b *ecsAPIService) parseLoadBalancerExtension(ctx context.Context, project *types.Project, r *awsResources) error {
if x, ok := project.Extensions[extensionLoadBalancer]; ok {
nameOrArn := x.(string)
loadBalancer, loadBalancerType, err := b.aws.ResolveLoadBalancer(ctx, nameOrArn)
loadBalancer, loadBalancerType, vpc, subnets, err := b.aws.ResolveLoadBalancer(ctx, nameOrArn)
if err != nil {
return nil, "", err
return err
}
required := getRequiredLoadBalancerType(project)
if loadBalancerType != required {
return nil, "", fmt.Errorf("load balancer %q is of type %s, project require a %s", nameOrArn, loadBalancerType, required)
return fmt.Errorf("load balancer %q is of type %s, project require a %s", nameOrArn, loadBalancerType, required)
}
return loadBalancer, loadBalancerType, err
r.loadBalancer = loadBalancer
r.loadBalancerType = loadBalancerType
r.vpc = vpc
r.subnets = subnets
return err
}
return nil, "", nil
return nil
}
func (b *ecsAPIService) parseExternalNetworks(ctx context.Context, project *types.Project) (map[string]string, error) {

View File

@ -6,13 +6,12 @@ package ecs
import (
context "context"
reflect "reflect"
cloudformation "github.com/aws/aws-sdk-go/service/cloudformation"
ecs "github.com/aws/aws-sdk-go/service/ecs"
compose "github.com/docker/compose-cli/api/compose"
secrets "github.com/docker/compose-cli/api/secrets"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockAPI is a mock of API interface
@ -455,7 +454,7 @@ func (mr *MockAPIMockRecorder) InspectSecret(arg0, arg1 interface{}) *gomock.Cal
}
// IsPublicSubnet mocks base method
func (m *MockAPI) IsPublicSubnet(ctx context.Context, arg0 string, arg1 string) (bool, error) {
func (m *MockAPI) IsPublicSubnet(arg0 context.Context, arg1 string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsPublicSubnet", arg0, arg1)
ret0, _ := ret[0].(bool)
@ -605,13 +604,15 @@ func (mr *MockAPIMockRecorder) ResolveFileSystem(arg0, arg1 interface{}) *gomock
}
// ResolveLoadBalancer mocks base method
func (m *MockAPI) ResolveLoadBalancer(arg0 context.Context, arg1 string) (awsResource, string, error) {
func (m *MockAPI) ResolveLoadBalancer(arg0 context.Context, arg1 string) (awsResource, string, string, []awsResource, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ResolveLoadBalancer", arg0, arg1)
ret0, _ := ret[0].(awsResource)
ret1, _ := ret[1].(string)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
ret2, _ := ret[2].(string)
ret3, _ := ret[3].([]awsResource)
ret4, _ := ret[4].(error)
return ret0, ret1, ret2, ret3, ret4
}
// ResolveLoadBalancer indicates an expected call of ResolveLoadBalancer

View File

@ -210,7 +210,7 @@ func (s sdk) GetSubNets(ctx context.Context, vpcID string) ([]awsResource, error
return ids, nil
}
func (s sdk) IsPublicSubnet(ctx context.Context, vpcID string, subNetID string) (bool, error) {
func (s sdk) IsPublicSubnet(ctx context.Context, subNetID string) (bool, error) {
tables, err := s.EC2.DescribeRouteTablesWithContext(ctx, &ec2.DescribeRouteTablesInput{
Filters: []*ec2.Filter{
{
@ -1045,14 +1045,14 @@ func (s sdk) GetPublicIPs(ctx context.Context, interfaces ...string) (map[string
}
}
func (s sdk) ResolveLoadBalancer(ctx context.Context, nameOrarn string) (awsResource, string, error) {
logrus.Debug("Check if LoadBalancer exists: ", nameOrarn)
func (s sdk) ResolveLoadBalancer(ctx context.Context, nameOrArn string) (awsResource, string, string, []awsResource, error) {
logrus.Debug("Check if LoadBalancer exists: ", nameOrArn)
var arns []*string
var names []*string
if arn.IsARN(nameOrarn) {
arns = append(arns, aws.String(nameOrarn))
if arn.IsARN(nameOrArn) {
arns = append(arns, aws.String(nameOrArn))
} else {
names = append(names, aws.String(nameOrarn))
names = append(names, aws.String(nameOrArn))
}
lbs, err := s.ELB.DescribeLoadBalancersWithContext(ctx, &elbv2.DescribeLoadBalancersInput{
@ -1060,16 +1060,22 @@ func (s sdk) ResolveLoadBalancer(ctx context.Context, nameOrarn string) (awsReso
Names: names,
})
if err != nil {
return nil, "", err
return nil, "", "", nil, err
}
if len(lbs.LoadBalancers) == 0 {
return nil, "", errors.Wrapf(errdefs.ErrNotFound, "load balancer %q does not exist", nameOrarn)
return nil, "", "", nil, errors.Wrapf(errdefs.ErrNotFound, "load balancer %q does not exist", nameOrArn)
}
it := lbs.LoadBalancers[0]
var subNets []awsResource
for _, az := range it.AvailabilityZones {
subNets = append(subNets, existingAWSResource{
id: aws.StringValue(az.SubnetId),
})
}
return existingAWSResource{
arn: aws.StringValue(it.LoadBalancerArn),
id: aws.StringValue(it.LoadBalancerName),
}, aws.StringValue(it.Type), nil
}, aws.StringValue(it.Type), aws.StringValue(it.VpcId), subNets, nil
}
func (s sdk) GetLoadBalancerURL(ctx context.Context, arn string) (string, error) {