From cd27fcb9c8c030898cdbd10deb92ac29b8529354 Mon Sep 17 00:00:00 2001 From: Nicolas De Loof Date: Fri, 28 Aug 2020 10:57:31 +0200 Subject: [PATCH] Check VPC has DNS resolution enabled Signed-off-by: Nicolas De Loof --- ecs/sdk.go | 17 +++++++++++++---- ecs/up.go | 16 ++++++++-------- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/ecs/sdk.go b/ecs/sdk.go index 19ffbe251..e0a45330f 100644 --- a/ecs/sdk.go +++ b/ecs/sdk.go @@ -105,10 +105,19 @@ func (s sdk) CreateCluster(ctx context.Context, name string) (string, error) { return *response.Cluster.Status, nil } -func (s sdk) VpcExists(ctx context.Context, vpcID string) (bool, error) { - logrus.Debug("CheckRequirements if VPC exists: ", vpcID) - _, err := s.EC2.DescribeVpcsWithContext(ctx, &ec2.DescribeVpcsInput{VpcIds: []*string{&vpcID}}) - return err == nil, err +func (s sdk) CheckVPC(ctx context.Context, vpcID string) error { + logrus.Debug("CheckRequirements on VPC : ", vpcID) + output, err := s.EC2.DescribeVpcAttributeWithContext(ctx, &ec2.DescribeVpcAttributeInput{ + VpcId: aws.String(vpcID), + Attribute: aws.String("enableDnsSupport"), + }) + if err != nil { + return err + } + if !*output.EnableDnsSupport.Value { + return fmt.Errorf("VPC %q doesn't have DNS resolution enabled", vpcID) + } + return err } func (s sdk) GetDefaultVPC(ctx context.Context) (string, error) { diff --git a/ecs/up.go b/ecs/up.go index cf16cadf9..6d20c5c41 100644 --- a/ecs/up.go +++ b/ecs/up.go @@ -103,23 +103,23 @@ func (b *ecsAPIService) Up(ctx context.Context, project *types.Project) error { } func (b ecsAPIService) GetVPC(ctx context.Context, project *types.Project) (string, error) { + var vpcID string //check compose file for custom VPC selected if vpc, ok := project.Extensions[extensionVPC]; ok { - vpcID := vpc.(string) - ok, err := b.SDK.VpcExists(ctx, vpcID) + vpcID = vpc.(string) + } else { + defaultVPC, err := b.SDK.GetDefaultVPC(ctx) if err != nil { return "", err } - if !ok { - return "", fmt.Errorf("VPC does not exist: %s", vpc) - } - return vpcID, nil + vpcID = defaultVPC } - defaultVPC, err := b.SDK.GetDefaultVPC(ctx) + + err := b.SDK.CheckVPC(ctx, vpcID) if err != nil { return "", err } - return defaultVPC, nil + return vpcID, nil } func (b ecsAPIService) GetLoadBalancer(ctx context.Context, project *types.Project) (string, error) {