diff --git a/.gitignore b/.gitignore index a5d8f7237..480478873 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ bin/ dist/ +/.vscode/ diff --git a/aci/backend.go b/aci/backend.go index 373326362..c9359c211 100644 --- a/aci/backend.go +++ b/aci/backend.go @@ -51,6 +51,7 @@ type LoginParams struct { TenantID string ClientID string ClientSecret string + CloudName string } // Validate returns an error if options are not used properly diff --git a/aci/backend_test.go b/aci/backend_test.go index 96432d903..909171657 100644 --- a/aci/backend_test.go +++ b/aci/backend_test.go @@ -23,7 +23,9 @@ import ( "github.com/stretchr/testify/mock" "gotest.tools/v3/assert" + "github.com/docker/compose-cli/aci/login" "github.com/docker/compose-cli/api/containers" + "golang.org/x/oauth2" ) func TestGetContainerName(t *testing.T) { @@ -82,7 +84,7 @@ func TestLoginParamsValidate(t *testing.T) { func TestLoginServicePrincipal(t *testing.T) { loginService := mockLoginService{} - loginService.On("LoginServicePrincipal", "someID", "secret", "tenant").Return(nil) + loginService.On("LoginServicePrincipal", "someID", "secret", "tenant", "someCloud").Return(nil) loginBackend := aciCloudService{ loginService: &loginService, } @@ -91,6 +93,7 @@ func TestLoginServicePrincipal(t *testing.T) { ClientID: "someID", ClientSecret: "secret", TenantID: "tenant", + CloudName: "someCloud", }) assert.NilError(t, err) @@ -99,13 +102,14 @@ func TestLoginServicePrincipal(t *testing.T) { func TestLoginWithTenant(t *testing.T) { loginService := mockLoginService{} ctx := context.Background() - loginService.On("Login", ctx, "tenant").Return(nil) + loginService.On("Login", ctx, "tenant", "someCloud").Return(nil) loginBackend := aciCloudService{ loginService: &loginService, } err := loginBackend.Login(ctx, LoginParams{ - TenantID: "tenant", + TenantID: "tenant", + CloudName: "someCloud", }) assert.NilError(t, err) @@ -114,12 +118,14 @@ func TestLoginWithTenant(t *testing.T) { func TestLoginWithoutTenant(t *testing.T) { loginService := mockLoginService{} ctx := context.Background() - loginService.On("Login", ctx, "").Return(nil) + loginService.On("Login", ctx, "", "someCloud").Return(nil) loginBackend := aciCloudService{ loginService: &loginService, } - err := loginBackend.Login(ctx, LoginParams{}) + err := loginBackend.Login(ctx, LoginParams{ + CloudName: "someCloud", + }) assert.NilError(t, err) } @@ -128,13 +134,13 @@ type mockLoginService struct { mock.Mock } -func (s *mockLoginService) Login(ctx context.Context, requestedTenantID string) error { - args := s.Called(ctx, requestedTenantID) +func (s *mockLoginService) Login(ctx context.Context, requestedTenantID string, cloudEnvironment string) error { + args := s.Called(ctx, requestedTenantID, cloudEnvironment) return args.Error(0) } -func (s *mockLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error { - args := s.Called(clientID, clientSecret, tenantID) +func (s *mockLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string, cloudEnvironment string) error { + args := s.Called(clientID, clientSecret, tenantID, cloudEnvironment) return args.Error(0) } @@ -142,3 +148,18 @@ func (s *mockLoginService) Logout(ctx context.Context) error { args := s.Called(ctx) return args.Error(0) } + +func (s *mockLoginService) GetTenantID() (string, error) { + args := s.Called() + return args.String(0), args.Error(1) +} + +func (s *mockLoginService) GetCloudEnvironment() (login.CloudEnvironment, error) { + args := s.Called() + return args.Get(0).(login.CloudEnvironment), args.Error(1) +} + +func (s *mockLoginService) GetValidToken() (oauth2.Token, string, error) { + args := s.Called() + return args.Get(0).(oauth2.Token), args.String(1), args.Error(2) +} diff --git a/aci/cloud.go b/aci/cloud.go index d47c174ac..d9b13b1a0 100644 --- a/aci/cloud.go +++ b/aci/cloud.go @@ -25,7 +25,7 @@ import ( ) type aciCloudService struct { - loginService login.AzureLoginServiceAPI + loginService login.AzureLoginService } func (cs *aciCloudService) Login(ctx context.Context, params interface{}) error { @@ -33,10 +33,13 @@ func (cs *aciCloudService) Login(ctx context.Context, params interface{}) error if !ok { return errors.New("could not read Azure LoginParams struct from generic parameter") } - if opts.ClientID != "" { - return cs.loginService.LoginServicePrincipal(opts.ClientID, opts.ClientSecret, opts.TenantID) + if opts.CloudName == "" { + opts.CloudName = login.AzurePublicCloudName } - return cs.loginService.Login(ctx, opts.TenantID) + if opts.ClientID != "" { + return cs.loginService.LoginServicePrincipal(opts.ClientID, opts.ClientSecret, opts.TenantID, opts.CloudName) + } + return cs.loginService.Login(ctx, opts.TenantID, opts.CloudName) } func (cs *aciCloudService) Logout(ctx context.Context) error { diff --git a/aci/convert/registry_credentials.go b/aci/convert/registry_credentials.go index 8919fb53a..321c27b5b 100644 --- a/aci/convert/registry_credentials.go +++ b/aci/convert/registry_credentials.go @@ -47,7 +47,7 @@ const ( type registryHelper interface { getAllRegistryCredentials() (map[string]types.AuthConfig, error) - autoLoginAcr(registry string) error + autoLoginAcr(registry string, loginService login.AzureLoginService) error } type cliRegistryHelper struct { @@ -65,9 +65,19 @@ func newCliRegistryConfLoader() cliRegistryHelper { } func getRegistryCredentials(project compose.Project, helper registryHelper) ([]containerinstance.ImageRegistryCredential, error) { - usedRegistries, acrRegistries := getUsedRegistries(project) + loginService, err := login.NewAzureLoginService() + if err != nil { + return nil, err + } + + var cloudEnvironment *login.CloudEnvironment = nil + if ce, err := loginService.GetCloudEnvironment(); err != nil { + cloudEnvironment = &ce + } + + usedRegistries, acrRegistries := getUsedRegistries(project, cloudEnvironment) for _, registry := range acrRegistries { - err := helper.autoLoginAcr(registry) + err := helper.autoLoginAcr(registry, loginService) if err != nil { fmt.Printf("WARNING: %v\n", err) fmt.Printf("Could not automatically login to %s from your Azure login. Assuming you already logged in to the ACR registry\n", registry) @@ -116,9 +126,10 @@ func getRegistryCredentials(project compose.Project, helper registryHelper) ([]c return registryCreds, nil } -func getUsedRegistries(project compose.Project) (map[string]bool, []string) { +func getUsedRegistries(project compose.Project, ce *login.CloudEnvironment) (map[string]bool, []string) { usedRegistries := map[string]bool{} acrRegistries := []string{} + for _, service := range project.Services { imageName := service.Image tokens := strings.Split(imageName, "/") @@ -127,24 +138,18 @@ func getUsedRegistries(project compose.Project) (map[string]bool, []string) { registry = dockerHub } else if !strings.Contains(registry, ".") { registry = dockerHub - } else if strings.HasSuffix(registry, login.AcrRegistrySuffix) { - acrRegistries = append(acrRegistries, registry) + } else if ce != nil { + if suffix, present := ce.Suffixes[login.AcrSuffixKey]; present && strings.HasSuffix(registry, suffix) { + acrRegistries = append(acrRegistries, registry) + } } usedRegistries[registry] = true } return usedRegistries, acrRegistries } -func (c cliRegistryHelper) autoLoginAcr(registry string) error { - loginService, err := login.NewAzureLoginService() - if err != nil { - return err - } - token, err := loginService.GetValidToken() - if err != nil { - return err - } - tenantID, err := loginService.GetTenantID() +func (c cliRegistryHelper) autoLoginAcr(registry string, loginService login.AzureLoginService) error { + token, tenantID, err := loginService.GetValidToken() if err != nil { return err } diff --git a/aci/convert/registry_credentials_test.go b/aci/convert/registry_credentials_test.go index 890776632..78ca7556a 100644 --- a/aci/convert/registry_credentials_test.go +++ b/aci/convert/registry_credentials_test.go @@ -25,6 +25,7 @@ import ( "github.com/Azure/go-autorest/autorest/to" "github.com/compose-spec/compose-go/types" cliconfigtypes "github.com/docker/cli/cli/config/types" + "github.com/docker/compose-cli/aci/login" "github.com/stretchr/testify/mock" "gotest.tools/v3/assert" is "gotest.tools/v3/assert/cmp" @@ -255,7 +256,7 @@ func (s *MockRegistryHelper) getAllRegistryCredentials() (map[string]cliconfigty return args.Get(0).(map[string]cliconfigtypes.AuthConfig), args.Error(1) } -func (s *MockRegistryHelper) autoLoginAcr(registry string) error { - args := s.Called(registry) +func (s *MockRegistryHelper) autoLoginAcr(registry string, loginService login.AzureLoginService) error { + args := s.Called(registry, loginService) return args.Error(0) } diff --git a/aci/login/client.go b/aci/login/client.go index 55b2fb979..0d6bafd4f 100644 --- a/aci/login/client.go +++ b/aci/login/client.go @@ -17,6 +17,8 @@ package login import ( + "encoding/json" + "strconv" "time" "github.com/Azure/azure-sdk-for-go/profiles/2019-03-01/resources/mgmt/resources" @@ -24,6 +26,8 @@ import ( "github.com/Azure/azure-sdk-for-go/services/containerinstance/mgmt/2019-12-01/containerinstance" "github.com/Azure/azure-sdk-for-go/services/storage/mgmt/2019-06-01/storage" "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/adal" + "github.com/Azure/go-autorest/autorest/date" "github.com/pkg/errors" "github.com/docker/compose-cli/api/errdefs" @@ -32,8 +36,12 @@ import ( // NewContainerGroupsClient get client toi manipulate containerGrouos func NewContainerGroupsClient(subscriptionID string) (containerinstance.ContainerGroupsClient, error) { - containerGroupsClient := containerinstance.NewContainerGroupsClient(subscriptionID) - err := setupClient(&containerGroupsClient.Client) + authorizer, mgmtURL, err := getClientSetupData() + if err != nil { + return containerinstance.ContainerGroupsClient{}, err + } + containerGroupsClient := containerinstance.NewContainerGroupsClientWithBaseURI(mgmtURL, subscriptionID) + setupClient(&containerGroupsClient.Client, authorizer) if err != nil { return containerinstance.ContainerGroupsClient{}, err } @@ -43,68 +51,100 @@ func NewContainerGroupsClient(subscriptionID string) (containerinstance.Containe return containerGroupsClient, nil } -func setupClient(aciClient *autorest.Client) error { +func setupClient(aciClient *autorest.Client, auth autorest.Authorizer) { aciClient.UserAgent = internal.UserAgentName + "/" + internal.Version - auth, err := NewAuthorizerFromLogin() - if err != nil { - return err - } aciClient.Authorizer = auth - return nil } // NewStorageAccountsClient get client to manipulate storage accounts func NewStorageAccountsClient(subscriptionID string) (storage.AccountsClient, error) { - containerGroupsClient := storage.NewAccountsClient(subscriptionID) - err := setupClient(&containerGroupsClient.Client) + authorizer, mgmtURL, err := getClientSetupData() if err != nil { return storage.AccountsClient{}, err } - containerGroupsClient.PollingDelay = 5 * time.Second - containerGroupsClient.RetryAttempts = 30 - containerGroupsClient.RetryDuration = 1 * time.Second - return containerGroupsClient, nil + storageAccuntsClient := storage.NewAccountsClientWithBaseURI(mgmtURL, subscriptionID) + setupClient(&storageAccuntsClient.Client, authorizer) + storageAccuntsClient.PollingDelay = 5 * time.Second + storageAccuntsClient.RetryAttempts = 30 + storageAccuntsClient.RetryDuration = 1 * time.Second + return storageAccuntsClient, nil } // NewFileShareClient get client to manipulate file shares func NewFileShareClient(subscriptionID string) (storage.FileSharesClient, error) { - containerGroupsClient := storage.NewFileSharesClient(subscriptionID) - err := setupClient(&containerGroupsClient.Client) + authorizer, mgmtURL, err := getClientSetupData() if err != nil { return storage.FileSharesClient{}, err } - containerGroupsClient.PollingDelay = 5 * time.Second - containerGroupsClient.RetryAttempts = 30 - containerGroupsClient.RetryDuration = 1 * time.Second - return containerGroupsClient, nil + fileSharesClient := storage.NewFileSharesClientWithBaseURI(mgmtURL, subscriptionID) + setupClient(&fileSharesClient.Client, authorizer) + fileSharesClient.PollingDelay = 5 * time.Second + fileSharesClient.RetryAttempts = 30 + fileSharesClient.RetryDuration = 1 * time.Second + return fileSharesClient, nil } // NewSubscriptionsClient get subscription client func NewSubscriptionsClient() (subscription.SubscriptionsClient, error) { - subc := subscription.NewSubscriptionsClient() - err := setupClient(&subc.Client) + authorizer, mgmtURL, err := getClientSetupData() if err != nil { return subscription.SubscriptionsClient{}, errors.Wrap(errdefs.ErrLoginRequired, err.Error()) } + subc := subscription.NewSubscriptionsClientWithBaseURI(mgmtURL) + setupClient(&subc.Client, authorizer) return subc, nil } // NewGroupsClient get client to manipulate groups func NewGroupsClient(subscriptionID string) (resources.GroupsClient, error) { - groupsClient := resources.NewGroupsClient(subscriptionID) - err := setupClient(&groupsClient.Client) + authorizer, mgmtURL, err := getClientSetupData() if err != nil { return resources.GroupsClient{}, err } + groupsClient := resources.NewGroupsClientWithBaseURI(mgmtURL, subscriptionID) + setupClient(&groupsClient.Client, authorizer) return groupsClient, nil } // NewContainerClient get client to manipulate containers func NewContainerClient(subscriptionID string) (containerinstance.ContainersClient, error) { - containerClient := containerinstance.NewContainersClient(subscriptionID) - err := setupClient(&containerClient.Client) + authorizer, mgmtURL, err := getClientSetupData() if err != nil { return containerinstance.ContainersClient{}, err } + containerClient := containerinstance.NewContainersClientWithBaseURI(mgmtURL, subscriptionID) + setupClient(&containerClient.Client, authorizer) return containerClient, nil } + +func getClientSetupData() (autorest.Authorizer, string, error) { + return getClientSetupDataImpl(GetTokenStorePath()) +} + +func getClientSetupDataImpl(tokenStorePath string) (autorest.Authorizer, string, error) { + als, err := newAzureLoginServiceFromPath(tokenStorePath, azureAPIHelper{}, CloudEnvironments) + if err != nil { + return nil, "", err + } + + oauthToken, _, err := als.GetValidToken() + if err != nil { + return nil, "", errors.Wrap(err, "not logged in to azure, you need to run \"docker login azure\" first") + } + + ce, err := als.GetCloudEnvironment() + if err != nil { + return nil, "", err + } + + token := adal.Token{ + AccessToken: oauthToken.AccessToken, + Type: oauthToken.TokenType, + ExpiresIn: json.Number(strconv.Itoa(int(time.Until(oauthToken.Expiry).Seconds()))), + ExpiresOn: json.Number(strconv.Itoa(int(oauthToken.Expiry.Sub(date.UnixEpoch()).Seconds()))), + RefreshToken: "", + Resource: "", + } + + return autorest.NewBearerAuthorizer(&token), ce.ResourceManagerURL, nil +} diff --git a/aci/login/client_test.go b/aci/login/client_test.go new file mode 100644 index 000000000..4c77eec14 --- /dev/null +++ b/aci/login/client_test.go @@ -0,0 +1,36 @@ +/* + Copyright 2020 Docker Compose CLI authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package login + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + + "gotest.tools/v3/assert" +) + +func TestClearErrorMessageIfNotAlreadyLoggedIn(t *testing.T) { + dir, err := ioutil.TempDir("", "test_store") + assert.NilError(t, err) + t.Cleanup(func() { + _ = os.RemoveAll(dir) + }) + _, _, err = getClientSetupDataImpl(filepath.Join(dir, tokenStoreFilename)) + assert.ErrorContains(t, err, "not logged in to azure, you need to run \"docker login azure\" first") +} diff --git a/aci/login/cloud_environment.go b/aci/login/cloud_environment.go new file mode 100644 index 000000000..d02bef6fc --- /dev/null +++ b/aci/login/cloud_environment.go @@ -0,0 +1,274 @@ +/* + Copyright 2020 Docker Compose CLI authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package login + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "os" + "strings" + + "github.com/pkg/errors" +) + +const ( + // AzurePublicCloudName is the moniker of the Azure public cloud + AzurePublicCloudName = "AzureCloud" + + // AcrSuffixKey is the well-known name of the DNS suffix for Azure Container Registries + AcrSuffixKey = "acrLoginServer" + + // CloudMetadataURLVar is the name of the environment variable that (if defined), points to a URL that should be used by Docker CLI to retrieve cloud metadata + CloudMetadataURLVar = "ARM_CLOUD_METADATA_URL" + + // DefaultCloudMetadataURL is the URL of the cloud metadata service maintained by Azure public cloud + DefaultCloudMetadataURL = "https://management.azure.com/metadata/endpoints?api-version=2019-05-01" +) + +// CloudEnvironmentService exposed metadata about Azure cloud environments +type CloudEnvironmentService interface { + Get(name string) (CloudEnvironment, error) +} + +type cloudEnvironmentService struct { + cloudEnvironments map[string]CloudEnvironment + cloudMetadataURL string + // True if we have queried the cloud metadata endpoint already. + // We do it only once per CLI invocation. + metadataQueried bool +} + +var ( + // CloudEnvironments is the default instance of the CloudEnvironmentService + CloudEnvironments CloudEnvironmentService +) + +func init() { + CloudEnvironments = newCloudEnvironmentService() +} + +// CloudEnvironmentAuthentication data for logging into, and obtaining tokens for, Azure sovereign clouds +type CloudEnvironmentAuthentication struct { + LoginEndpoint string `json:"loginEndpoint"` + Audiences []string `json:"audiences"` + Tenant string `json:"tenant"` +} + +// CloudEnvironment describes Azure sovereign cloud instance (e.g. Azure public, Azure US government, Azure China etc.) +type CloudEnvironment struct { + Name string `json:"name"` + Authentication CloudEnvironmentAuthentication `json:"authentication"` + ResourceManagerURL string `json:"resourceManager"` + Suffixes map[string]string `json:"suffixes"` +} + +func newCloudEnvironmentService() *cloudEnvironmentService { + retval := cloudEnvironmentService{ + metadataQueried: false, + } + retval.resetCloudMetadata() + return &retval +} + +func (ces *cloudEnvironmentService) Get(name string) (CloudEnvironment, error) { + if ce, present := ces.cloudEnvironments[name]; present { + return ce, nil + } + + if !ces.metadataQueried { + ces.metadataQueried = true + + if ces.cloudMetadataURL == "" { + ces.cloudMetadataURL = os.Getenv(CloudMetadataURLVar) + if _, err := url.ParseRequestURI(ces.cloudMetadataURL); err != nil { + ces.cloudMetadataURL = DefaultCloudMetadataURL + } + } + + res, err := http.Get(ces.cloudMetadataURL) + if err != nil { + return CloudEnvironment{}, fmt.Errorf("Cloud metadata retrieval from '%s' failed: %w", ces.cloudMetadataURL, err) + } + if res.StatusCode != 200 { + return CloudEnvironment{}, errors.Errorf("Cloud metadata retrieval from '%s' failed: server response was '%s'", ces.cloudMetadataURL, res.Status) + } + + bytes, err := ioutil.ReadAll(res.Body) + if err != nil { + return CloudEnvironment{}, fmt.Errorf("Cloud metadata retrieval from '%s' failed: %w", ces.cloudMetadataURL, err) + } + + if err = ces.applyCloudMetadata(bytes); err != nil { + return CloudEnvironment{}, fmt.Errorf("Cloud metadata retrieval from '%s' failed: %w", ces.cloudMetadataURL, err) + } + } + + if ce, present := ces.cloudEnvironments[name]; present { + return ce, nil + } + + return CloudEnvironment{}, errors.Errorf("Cloud environment '%s' was not found", name) +} + +func (ces *cloudEnvironmentService) applyCloudMetadata(jsonBytes []byte) error { + input := []CloudEnvironment{} + if err := json.Unmarshal(jsonBytes, &input); err != nil { + return err + } + + newEnvironments := make(map[string]CloudEnvironment, len(input)) + // If _any_ of the submitted data is invalid, we bail out. + for _, e := range input { + if len(e.Name) == 0 { + return errors.New("Azure cloud environment metadata is invalid (an environment with no name has been encountered)") + } + + e.normalizeURLs() + + if _, err := url.ParseRequestURI(e.Authentication.LoginEndpoint); err != nil { + return errors.Errorf("Metadata of cloud environment '%s' has invalid login endpoint URL: %s", e.Name, e.Authentication.LoginEndpoint) + } + + if _, err := url.ParseRequestURI(e.ResourceManagerURL); err != nil { + return errors.Errorf("Metadata of cloud environment '%s' has invalid resource manager URL: %s", e.Name, e.ResourceManagerURL) + } + + if len(e.Authentication.Audiences) == 0 { + return errors.Errorf("Metadata of cloud environment '%s' is invalid (no authentication audiences)", e.Name) + } + + newEnvironments[e.Name] = e + } + + for name, e := range newEnvironments { + ces.cloudEnvironments[name] = e + } + return nil +} + +func (ces *cloudEnvironmentService) resetCloudMetadata() { + azurePublicCloud := CloudEnvironment{ + Name: AzurePublicCloudName, + Authentication: CloudEnvironmentAuthentication{ + LoginEndpoint: "https://login.microsoftonline.com", + Audiences: []string{ + "https://management.core.windows.net", + "https://management.azure.com", + }, + Tenant: "common", + }, + ResourceManagerURL: "https://management.azure.com", + Suffixes: map[string]string{ + AcrSuffixKey: "azurecr.io", + }, + } + + azureChinaCloud := CloudEnvironment{ + Name: "AzureChinaCloud", + Authentication: CloudEnvironmentAuthentication{ + LoginEndpoint: "https://login.chinacloudapi.cn", + Audiences: []string{ + "https://management.core.chinacloudapi.cn", + "https://management.chinacloudapi.cn", + }, + Tenant: "common", + }, + ResourceManagerURL: "https://management.chinacloudapi.cn", + Suffixes: map[string]string{ + AcrSuffixKey: "azurecr.cn", + }, + } + + azureUSGovernment := CloudEnvironment{ + Name: "AzureUSGovernment", + Authentication: CloudEnvironmentAuthentication{ + LoginEndpoint: "https://login.microsoftonline.us", + Audiences: []string{ + "https://management.core.usgovcloudapi.net", + "https://management.usgovcloudapi.net", + }, + Tenant: "common", + }, + ResourceManagerURL: "https://management.usgovcloudapi.net", + Suffixes: map[string]string{ + AcrSuffixKey: "azurecr.us", + }, + } + + azureGermanCloud := CloudEnvironment{ + Name: "AzureGermanCloud", + Authentication: CloudEnvironmentAuthentication{ + LoginEndpoint: "https://login.microsoftonline.de", + Audiences: []string{ + "https://management.core.cloudapi.de", + "https://management.microsoftazure.de", + }, + Tenant: "common", + }, + ResourceManagerURL: "https://management.microsoftazure.de", + + // There is no separate container registry suffix for German cloud + Suffixes: map[string]string{}, + } + + ces.cloudEnvironments = map[string]CloudEnvironment{ + azurePublicCloud.Name: azurePublicCloud, + azureChinaCloud.Name: azureChinaCloud, + azureUSGovernment.Name: azureUSGovernment, + azureGermanCloud.Name: azureGermanCloud, + } +} + +// GetTenantQueryURL returns an URL that can be used to fetch the list of Azure Active Directory tenants from a given cloud environment +func (ce *CloudEnvironment) GetTenantQueryURL() string { + tenantURL := ce.ResourceManagerURL + "/tenants?api-version=2019-11-01" + return tenantURL +} + +// GetTokenScope returns a token scope that fits Docker CLI Azure management API usage +func (ce *CloudEnvironment) GetTokenScope() string { + scope := "offline_access " + ce.ResourceManagerURL + "/.default" + return scope +} + +// GetAuthorizeRequestFormat returns a string format that can be used to construct authorization code request in an OAuth2 flow. +// The URL uses login endpoint appropriate for given cloud environment. +func (ce *CloudEnvironment) GetAuthorizeRequestFormat() string { + authorizeFormat := ce.Authentication.LoginEndpoint + "/organizations/oauth2/v2.0/authorize?response_type=code&client_id=%s&redirect_uri=%s&state=%s&prompt=select_account&response_mode=query&scope=%s" + return authorizeFormat +} + +// GetTokenRequestFormat returns a string format that can be used to construct a security token request against Azure Active Directory +func (ce *CloudEnvironment) GetTokenRequestFormat() string { + tokenEndpoint := ce.Authentication.LoginEndpoint + "/%s/oauth2/v2.0/token" + return tokenEndpoint +} + +func (ce *CloudEnvironment) normalizeURLs() { + ce.ResourceManagerURL = removeTrailingSlash(ce.ResourceManagerURL) + ce.Authentication.LoginEndpoint = removeTrailingSlash(ce.Authentication.LoginEndpoint) + for i, s := range ce.Authentication.Audiences { + ce.Authentication.Audiences[i] = removeTrailingSlash(s) + } +} + +func removeTrailingSlash(s string) string { + return strings.TrimSuffix(s, "/") +} diff --git a/aci/login/cloud_environment_test.go b/aci/login/cloud_environment_test.go new file mode 100644 index 000000000..57577761a --- /dev/null +++ b/aci/login/cloud_environment_test.go @@ -0,0 +1,187 @@ +/* + Copyright 2020 Docker Compose CLI authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package login + +import ( + "testing" + + "gotest.tools/v3/assert" +) + +func TestNormalizeCloudEnvironmentURLs(t *testing.T) { + ce := CloudEnvironment{ + Name: "SecretCloud", + Authentication: CloudEnvironmentAuthentication{ + LoginEndpoint: "https://login.here.com/", + Audiences: []string{ + "https://audience1", + "https://audience2/", + }, + Tenant: "common", + }, + ResourceManagerURL: "invalid URL", + } + + ce.normalizeURLs() + + assert.Equal(t, ce.Authentication.LoginEndpoint, "https://login.here.com") + assert.Equal(t, ce.Authentication.Audiences[0], "https://audience1") + assert.Equal(t, ce.Authentication.Audiences[1], "https://audience2") +} + +func TestApplyInvalidCloudMetadataJSON(t *testing.T) { + ce := newCloudEnvironmentService() + bb := []byte(`This isn't really valid JSON`) + + err := ce.applyCloudMetadata(bb) + + assert.Assert(t, err != nil, "Cloud metadata was invalid, so the application should have failed") + ensureWellKnownCloudMetadata(t, ce) +} + +func TestApplyInvalidCloudMetatada(t *testing.T) { + ce := newCloudEnvironmentService() + + // No name (moniker) for the cloud + bb := []byte(` + [{ + "authentication": { + "loginEndpoint": "https://login.docker.com/", + "audiences": [ + "https://management.docker.com/", + "https://management.cli.docker.com/" + ], + "tenant": "F5773994-FE88-482E-9E33-6E799D250416" + }, + "suffixes": { + "acrLoginServer": "azurecr.docker.io" + }, + "resourceManager": "https://management.docker.com/" + }]`) + + err := ce.applyCloudMetadata(bb) + assert.ErrorContains(t, err, "no name") + ensureWellKnownCloudMetadata(t, ce) + + // Invalid resource manager URL + bb = []byte(` + [{ + "authentication": { + "loginEndpoint": "https://login.docker.com/", + "audiences": [ + "https://management.docker.com/", + "https://management.cli.docker.com/" + ], + "tenant": "F5773994-FE88-482E-9E33-6E799D250416" + }, + "name": "DockerAzureCloud", + "suffixes": { + "acrLoginServer": "azurecr.docker.io" + }, + "resourceManager": "123" + }]`) + + err = ce.applyCloudMetadata(bb) + assert.ErrorContains(t, err, "invalid resource manager URL") + ensureWellKnownCloudMetadata(t, ce) + + // Invalid login endpoint + bb = []byte(` + [{ + "authentication": { + "loginEndpoint": "456", + "audiences": [ + "https://management.docker.com/", + "https://management.cli.docker.com/" + ], + "tenant": "F5773994-FE88-482E-9E33-6E799D250416" + }, + "name": "DockerAzureCloud", + "suffixes": { + "acrLoginServer": "azurecr.docker.io" + }, + "resourceManager": "https://management.docker.com/" + }]`) + + err = ce.applyCloudMetadata(bb) + assert.ErrorContains(t, err, "invalid login endpoint") + ensureWellKnownCloudMetadata(t, ce) + + // No audiences + bb = []byte(` + [{ + "authentication": { + "loginEndpoint": "https://login.docker.com/", + "audiences": [ ], + "tenant": "F5773994-FE88-482E-9E33-6E799D250416" + }, + "name": "DockerAzureCloud", + "suffixes": { + "acrLoginServer": "azurecr.docker.io" + }, + "resourceManager": "https://management.docker.com/" + }]`) + + err = ce.applyCloudMetadata(bb) + assert.ErrorContains(t, err, "no authentication audiences") + ensureWellKnownCloudMetadata(t, ce) +} + +func TestApplyCloudMetadata(t *testing.T) { + ce := newCloudEnvironmentService() + + bb := []byte(` + [{ + "authentication": { + "loginEndpoint": "https://login.docker.com/", + "audiences": [ + "https://management.docker.com/", + "https://management.cli.docker.com/" + ], + "tenant": "F5773994-FE88-482E-9E33-6E799D250416" + }, + "name": "DockerAzureCloud", + "suffixes": { + "acrLoginServer": "azurecr.docker.io" + }, + "resourceManager": "https://management.docker.com/" + }]`) + + err := ce.applyCloudMetadata(bb) + assert.NilError(t, err) + + env, err := ce.Get("DockerAzureCloud") + assert.NilError(t, err) + assert.Equal(t, env.Authentication.LoginEndpoint, "https://login.docker.com") + ensureWellKnownCloudMetadata(t, ce) +} + +func TestDefaultCloudMetadataPresent(t *testing.T) { + ensureWellKnownCloudMetadata(t, CloudEnvironments) +} + +func ensureWellKnownCloudMetadata(t *testing.T, ce CloudEnvironmentService) { + // Make sure well-known public cloud information is still available + _, err := ce.Get(AzurePublicCloudName) + assert.NilError(t, err) + + _, err = ce.Get("AzureChinaCloud") + assert.NilError(t, err) + + _, err = ce.Get("AzureUSGovernment") + assert.NilError(t, err) +} diff --git a/aci/login/helper.go b/aci/login/helper.go index 4d20624a5..2a5ec13f4 100644 --- a/aci/login/helper.go +++ b/aci/login/helper.go @@ -39,17 +39,17 @@ var ( ) type apiHelper interface { - queryToken(data url.Values, tenantID string) (azureToken, error) - openAzureLoginPage(redirectURL string) error + queryToken(ce CloudEnvironment, data url.Values, tenantID string) (azureToken, error) + openAzureLoginPage(redirectURL string, ce CloudEnvironment) error queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) - getDeviceCodeFlowToken() (adal.Token, error) + getDeviceCodeFlowToken(ce CloudEnvironment) (adal.Token, error) } type azureAPIHelper struct{} -func (helper azureAPIHelper) getDeviceCodeFlowToken() (adal.Token, error) { +func (helper azureAPIHelper) getDeviceCodeFlowToken(ce CloudEnvironment) (adal.Token, error) { deviceconfig := auth.NewDeviceFlowConfig(clientID, "common") - deviceconfig.Resource = azureManagementURL + deviceconfig.Resource = ce.ResourceManagerURL spToken, err := deviceconfig.ServicePrincipalToken() if err != nil { return adal.Token{}, err @@ -57,9 +57,9 @@ func (helper azureAPIHelper) getDeviceCodeFlowToken() (adal.Token, error) { return spToken.Token(), err } -func (helper azureAPIHelper) openAzureLoginPage(redirectURL string) error { +func (helper azureAPIHelper) openAzureLoginPage(redirectURL string, ce CloudEnvironment) error { state := randomString("", 10) - authURL := fmt.Sprintf(authorizeFormat, clientID, redirectURL, state, scopes) + authURL := fmt.Sprintf(ce.GetAuthorizeRequestFormat(), clientID, redirectURL, state, ce.GetTokenScope()) return openbrowser(authURL) } @@ -81,8 +81,8 @@ func (helper azureAPIHelper) queryAPIWithHeader(ctx context.Context, authorizati return bits, res.StatusCode, nil } -func (helper azureAPIHelper) queryToken(data url.Values, tenantID string) (azureToken, error) { - res, err := http.Post(fmt.Sprintf(tokenEndpoint, tenantID), "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) +func (helper azureAPIHelper) queryToken(ce CloudEnvironment, data url.Values, tenantID string) (azureToken, error) { + res, err := http.Post(fmt.Sprintf(ce.GetTokenRequestFormat(), tenantID), "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) if err != nil { return azureToken{}, err } diff --git a/aci/login/login.go b/aci/login/login.go index 21cb30fcd..805b84239 100644 --- a/aci/login/login.go +++ b/aci/login/login.go @@ -23,13 +23,10 @@ import ( "net/http" "net/url" "os" - "strconv" "time" - "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/adal" "github.com/Azure/go-autorest/autorest/azure/auth" - "github.com/Azure/go-autorest/autorest/date" "github.com/pkg/errors" "golang.org/x/oauth2" @@ -38,18 +35,6 @@ import ( //go login process, derived from code sample provided by MS at https://github.com/devigned/go-az-cli-stuff const ( - // AcrRegistrySuffix suffix for ACR registry images - AcrRegistrySuffix = ".azurecr.io" - activeDirectoryURL = "https://login.microsoftonline.com" - azureManagementURL = "https://management.core.windows.net/" - azureResouceManagementURL = "https://management.azure.com/" - authorizeFormat = activeDirectoryURL + "/organizations/oauth2/v2.0/authorize?response_type=code&client_id=%s&redirect_uri=%s&state=%s&prompt=select_account&response_mode=query&scope=%s" - tokenEndpoint = activeDirectoryURL + "/%s/oauth2/v2.0/token" - getTenantURL = azureResouceManagementURL + "tenants?api-version=2019-11-01" - - // scopes for a multi-tenant app works for openid, email, other common scopes, but fails when trying to add a token - // v1 scope like "https://management.azure.com/.default" for ARM access - scopes = "offline_access " + azureResouceManagementURL + ".default" clientID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46" // Azure CLI client id ) @@ -73,39 +58,41 @@ type ( ) // AzureLoginService Service to log into azure and get authentifier for azure APIs -type AzureLoginService struct { - tokenStore tokenStore - apiHelper apiHelper -} - -// AzureLoginServiceAPI interface for Azure login service -type AzureLoginServiceAPI interface { - LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error - Login(ctx context.Context, requestedTenantID string) error +type AzureLoginService interface { + Login(ctx context.Context, requestedTenantID string, cloudEnvironment string) error + LoginServicePrincipal(clientID string, clientSecret string, tenantID string, cloudEnvironment string) error Logout(ctx context.Context) error + GetCloudEnvironment() (CloudEnvironment, error) + GetValidToken() (oauth2.Token, string, error) +} +type azureLoginService struct { + tokenStore tokenStore + apiHelper apiHelper + cloudEnvironmentSvc CloudEnvironmentService } const tokenStoreFilename = "dockerAccessToken.json" // NewAzureLoginService creates a NewAzureLoginService -func NewAzureLoginService() (*AzureLoginService, error) { - return newAzureLoginServiceFromPath(GetTokenStorePath(), azureAPIHelper{}) +func NewAzureLoginService() (AzureLoginService, error) { + return newAzureLoginServiceFromPath(GetTokenStorePath(), azureAPIHelper{}, CloudEnvironments) } -func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper) (*AzureLoginService, error) { +func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper, ces CloudEnvironmentService) (*azureLoginService, error) { store, err := newTokenStore(tokenStorePath) if err != nil { return nil, err } - return &AzureLoginService{ - tokenStore: store, - apiHelper: helper, + return &azureLoginService{ + tokenStore: store, + apiHelper: helper, + cloudEnvironmentSvc: ces, }, nil } // LoginServicePrincipal login with clientId / clientSecret from a service principal. // The resulting token does not include a refresh token -func (login *AzureLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error { +func (login *azureLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string, cloudEnvironment string) error { // Tried with auth2.NewUsernamePasswordConfig() but could not make this work with username / password, setting this for CI with clientID / clientSecret creds := auth.NewClientCredentialsConfig(clientID, clientSecret, tenantID) @@ -121,7 +108,7 @@ func (login *AzureLoginService) LoginServicePrincipal(clientID string, clientSec if err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "could not read service principal token expiry: %s", err) } - loginInfo := TokenInfo{TenantID: tenantID, Token: token} + loginInfo := TokenInfo{TenantID: tenantID, Token: token, CloudEnvironment: cloudEnvironment} if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err) @@ -130,7 +117,7 @@ func (login *AzureLoginService) LoginServicePrincipal(clientID string, clientSec } // Logout remove azure token data -func (login *AzureLoginService) Logout(ctx context.Context) error { +func (login *azureLoginService) Logout(ctx context.Context) error { err := login.tokenStore.removeData() if os.IsNotExist(err) { return errors.New("No Azure login data to be removed") @@ -138,8 +125,14 @@ func (login *AzureLoginService) Logout(ctx context.Context) error { return err } -func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, accessToken string, refreshToken string, requestedTenantID string) error { - bits, statusCode, err := login.apiHelper.queryAPIWithHeader(ctx, getTenantURL, fmt.Sprintf("Bearer %s", accessToken)) +func (login *azureLoginService) getTenantAndValidateLogin( + ctx context.Context, + accessToken string, + refreshToken string, + requestedTenantID string, + ce CloudEnvironment, +) error { + bits, statusCode, err := login.apiHelper.queryAPIWithHeader(ctx, ce.GetTenantQueryURL(), fmt.Sprintf("Bearer %s", accessToken)) if err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err) } @@ -155,11 +148,11 @@ func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, a if err != nil { return errors.Wrap(errdefs.ErrLoginFailed, err.Error()) } - tToken, err := login.refreshToken(refreshToken, tenantID) + tToken, err := login.refreshToken(refreshToken, tenantID, ce) if err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "unable to refresh token: %s", err) } - loginInfo := TokenInfo{TenantID: tenantID, Token: tToken} + loginInfo := TokenInfo{TenantID: tenantID, Token: tToken, CloudEnvironment: ce.Name} if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err) @@ -168,7 +161,12 @@ func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, a } // Login performs an Azure login through a web browser -func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID string) error { +func (login *azureLoginService) Login(ctx context.Context, requestedTenantID string, cloudEnvironment string) error { + ce, err := login.cloudEnvironmentSvc.Get(cloudEnvironment) + if err != nil { + return err + } + queryCh := make(chan localResponse, 1) s, err := NewLocalServer(queryCh) if err != nil { @@ -183,8 +181,8 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str } deviceCodeFlowCh := make(chan deviceCodeFlowResponse, 1) - if err = login.apiHelper.openAzureLoginPage(redirectURL); err != nil { - login.startDeviceCodeFlow(deviceCodeFlowCh) + if err = login.apiHelper.openAzureLoginPage(redirectURL, ce); err != nil { + login.startDeviceCodeFlow(deviceCodeFlowCh, ce) } select { @@ -195,7 +193,7 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err) } token := dcft.token - return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID) + return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID, ce) case q := <-queryCh: if q.err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "unhandled local login server error: %s", err) @@ -208,14 +206,14 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str "grant_type": []string{"authorization_code"}, "client_id": []string{clientID}, "code": code, - "scope": []string{scopes}, + "scope": []string{ce.GetTokenScope()}, "redirect_uri": []string{redirectURL}, } - token, err := login.apiHelper.queryToken(data, "organizations") + token, err := login.apiHelper.queryToken(ce, data, "organizations") if err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err) } - return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID) + return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID, ce) } } @@ -224,10 +222,10 @@ type deviceCodeFlowResponse struct { err error } -func (login *AzureLoginService) startDeviceCodeFlow(deviceCodeFlowCh chan deviceCodeFlowResponse) { +func (login *azureLoginService) startDeviceCodeFlow(deviceCodeFlowCh chan deviceCodeFlowResponse, ce CloudEnvironment) { fmt.Println("Could not automatically open a browser, falling back to Azure device code flow authentication") go func() { - token, err := login.apiHelper.getDeviceCodeFlowToken() + token, err := login.apiHelper.getDeviceCodeFlowToken(ce) if err != nil { deviceCodeFlowCh <- deviceCodeFlowResponse{err: err} } @@ -276,72 +274,58 @@ func spToOAuthToken(token adal.Token) (oauth2.Token, error) { return oauthToken, nil } -// NewAuthorizerFromLogin creates an authorizer based on login access token -func NewAuthorizerFromLogin() (autorest.Authorizer, error) { - return newAuthorizerFromLoginStorePath(GetTokenStorePath()) -} - -func newAuthorizerFromLoginStorePath(storeTokenPath string) (autorest.Authorizer, error) { - login, err := newAzureLoginServiceFromPath(storeTokenPath, azureAPIHelper{}) - if err != nil { - return nil, err - } - oauthToken, err := login.GetValidToken() - if err != nil { - return nil, errors.Wrap(err, "not logged in to azure, you need to run \"docker login azure\" first") - } - - token := adal.Token{ - AccessToken: oauthToken.AccessToken, - Type: oauthToken.TokenType, - ExpiresIn: json.Number(strconv.Itoa(int(time.Until(oauthToken.Expiry).Seconds()))), - ExpiresOn: json.Number(strconv.Itoa(int(oauthToken.Expiry.Sub(date.UnixEpoch()).Seconds()))), - RefreshToken: "", - Resource: "", - } - - return autorest.NewBearerAuthorizer(&token), nil -} - -// GetTenantID returns tenantID for current login -func (login AzureLoginService) GetTenantID() (string, error) { +// GetValidToken returns an access token and associated tenant ID. +// Will refresh the token as necessary. +func (login *azureLoginService) GetValidToken() (oauth2.Token, string, error) { loginInfo, err := login.tokenStore.readToken() if err != nil { - return "", err - } - return loginInfo.TenantID, err -} - -// GetValidToken returns an access token. Refresh token if needed -func (login *AzureLoginService) GetValidToken() (oauth2.Token, error) { - loginInfo, err := login.tokenStore.readToken() - if err != nil { - return oauth2.Token{}, err + return oauth2.Token{}, "", err } token := loginInfo.Token - if token.Valid() { - return token, nil - } tenantID := loginInfo.TenantID - token, err = login.refreshToken(token.RefreshToken, tenantID) - if err != nil { - return oauth2.Token{}, errors.Wrap(err, "access token request failed. Maybe you need to login to azure again.") + if token.Valid() { + return token, tenantID, nil } - err = login.tokenStore.writeLoginInfo(TokenInfo{TenantID: tenantID, Token: token}) + + ce, err := login.cloudEnvironmentSvc.Get(loginInfo.CloudEnvironment) if err != nil { - return oauth2.Token{}, err + return oauth2.Token{}, "", errors.Wrap(err, "access token request failed--cloud environment could not be determined.") } - return token, nil + + token, err = login.refreshToken(token.RefreshToken, tenantID, ce) + if err != nil { + return oauth2.Token{}, "", errors.Wrap(err, "access token request failed. Maybe you need to login to Azure again.") + } + err = login.tokenStore.writeLoginInfo(TokenInfo{TenantID: tenantID, Token: token, CloudEnvironment: ce.Name}) + if err != nil { + return oauth2.Token{}, "", err + } + return token, tenantID, nil } -func (login *AzureLoginService) refreshToken(currentRefreshToken string, tenantID string) (oauth2.Token, error) { +// GeCloudEnvironment returns the cloud environment associated with the current authentication token (if we have one) +func (login *azureLoginService) GetCloudEnvironment() (CloudEnvironment, error) { + tokenInfo, err := login.tokenStore.readToken() + if err != nil { + return CloudEnvironment{}, err + } + + cloudEnvironment, err := login.cloudEnvironmentSvc.Get(tokenInfo.CloudEnvironment) + if err != nil { + return CloudEnvironment{}, err + } + + return cloudEnvironment, nil +} + +func (login *azureLoginService) refreshToken(currentRefreshToken string, tenantID string, ce CloudEnvironment) (oauth2.Token, error) { data := url.Values{ "grant_type": []string{"refresh_token"}, "client_id": []string{clientID}, - "scope": []string{scopes}, + "scope": []string{ce.GetTokenScope()}, "refresh_token": []string{currentRefreshToken}, } - token, err := login.apiHelper.queryToken(data, tenantID) + token, err := login.apiHelper.queryToken(ce, data, tenantID) if err != nil { return oauth2.Token{}, err } diff --git a/aci/login/login_test.go b/aci/login/login_test.go index 12330fe93..c1d7a3441 100644 --- a/aci/login/login_test.go +++ b/aci/login/login_test.go @@ -21,10 +21,12 @@ import ( "errors" "io/ioutil" "net/http" + "net/http/httptest" "net/url" "os" "path/filepath" "reflect" + "sync/atomic" "testing" "time" @@ -36,7 +38,7 @@ import ( "golang.org/x/oauth2" ) -func testLoginService(t *testing.T, m *MockAzureHelper) (*AzureLoginService, error) { +func testLoginService(t *testing.T, apiHelperMock *MockAzureHelper, cloudEnvironmentSvc CloudEnvironmentService) (*azureLoginService, error) { dir, err := ioutil.TempDir("", "test_store") if err != nil { return nil, err @@ -44,20 +46,45 @@ func testLoginService(t *testing.T, m *MockAzureHelper) (*AzureLoginService, err t.Cleanup(func() { _ = os.RemoveAll(dir) }) - return newAzureLoginServiceFromPath(filepath.Join(dir, tokenStoreFilename), m) + + ces := CloudEnvironments + if cloudEnvironmentSvc != nil { + ces = cloudEnvironmentSvc + } + return newAzureLoginServiceFromPath(filepath.Join(dir, tokenStoreFilename), apiHelperMock, ces) } func TestRefreshInValidToken(t *testing.T) { - data := refreshTokenData("refreshToken") - m := &MockAzureHelper{} - m.On("queryToken", data, "123456").Return(azureToken{ + data := url.Values{ + "grant_type": []string{"refresh_token"}, + "client_id": []string{clientID}, + "scope": []string{"offline_access https://management.docker.com/.default"}, + "refresh_token": []string{"refreshToken"}, + } + helperMock := &MockAzureHelper{} + helperMock.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "123456").Return(azureToken{ RefreshToken: "newRefreshToken", AccessToken: "newAccessToken", ExpiresIn: 3600, Foci: "1", }, nil) - azureLogin, err := testLoginService(t, m) + cloudEnvironmentSvcMock := &MockCloudEnvironmentService{} + cloudEnvironmentSvcMock.On("Get", "AzureDockerCloud").Return(CloudEnvironment{ + Name: "AzureDockerCloud", + Authentication: CloudEnvironmentAuthentication{ + LoginEndpoint: "https://login.docker.com", + Audiences: []string{ + "https://management.docker.com", + "https://management-ext.docker.com", + }, + Tenant: "common", + }, + ResourceManagerURL: "https://management.docker.com", + Suffixes: map[string]string{}, + }, nil) + + azureLogin, err := testLoginService(t, helperMock, cloudEnvironmentSvcMock) assert.NilError(t, err) err = azureLogin.tokenStore.writeLoginInfo(TokenInfo{ TenantID: "123456", @@ -67,33 +94,29 @@ func TestRefreshInValidToken(t *testing.T) { Expiry: time.Now().Add(-1 * time.Hour), TokenType: "Bearer", }, + CloudEnvironment: "AzureDockerCloud", }) assert.NilError(t, err) - token, _ := azureLogin.GetValidToken() + token, tenantID, err := azureLogin.GetValidToken() + assert.NilError(t, err) + assert.Equal(t, tenantID, "123456") assert.Equal(t, token.AccessToken, "newAccessToken") assert.Assert(t, time.Now().Add(3500*time.Second).Before(token.Expiry)) - storedToken, _ := azureLogin.tokenStore.readToken() + storedToken, err := azureLogin.tokenStore.readToken() + assert.NilError(t, err) assert.Equal(t, storedToken.Token.AccessToken, "newAccessToken") assert.Equal(t, storedToken.Token.RefreshToken, "newRefreshToken") assert.Assert(t, time.Now().Add(3500*time.Second).Before(storedToken.Token.Expiry)) -} -func TestClearErrorMessageIfNotAlreadyLoggedIn(t *testing.T) { - dir, err := ioutil.TempDir("", "test_store") - assert.NilError(t, err) - t.Cleanup(func() { - _ = os.RemoveAll(dir) - }) - _, err = newAuthorizerFromLoginStorePath(filepath.Join(dir, tokenStoreFilename)) - assert.ErrorContains(t, err, "not logged in to azure, you need to run \"docker login azure\" first") + assert.Equal(t, storedToken.CloudEnvironment, "AzureDockerCloud") } func TestDoesNotRefreshValidToken(t *testing.T) { expiryDate := time.Now().Add(1 * time.Hour) - azureLogin, err := testLoginService(t, nil) + azureLogin, err := testLoginService(t, nil, nil) assert.NilError(t, err) err = azureLogin.tokenStore.writeLoginInfo(TokenInfo{ TenantID: "123456", @@ -103,25 +126,55 @@ func TestDoesNotRefreshValidToken(t *testing.T) { Expiry: expiryDate, TokenType: "Bearer", }, + CloudEnvironment: AzurePublicCloudName, }) assert.NilError(t, err) - token, _ := azureLogin.GetValidToken() + token, tenantID, err := azureLogin.GetValidToken() + assert.NilError(t, err) assert.Equal(t, token.AccessToken, "accessToken") + assert.Equal(t, tenantID, "123456") +} + +func TestTokenStoreAssumesAzurePublicCloud(t *testing.T) { + expiryDate := time.Now().Add(1 * time.Hour) + azureLogin, err := testLoginService(t, nil, nil) + assert.NilError(t, err) + err = azureLogin.tokenStore.writeLoginInfo(TokenInfo{ + TenantID: "123456", + Token: oauth2.Token{ + AccessToken: "accessToken", + RefreshToken: "refreshToken", + Expiry: expiryDate, + TokenType: "Bearer", + }, + // Simulates upgrade from older version of Docker CLI that did not have cloud environment concept + CloudEnvironment: "", + }) + assert.NilError(t, err) + + token, tenantID, err := azureLogin.GetValidToken() + assert.NilError(t, err) + assert.Equal(t, tenantID, "123456") + assert.Equal(t, token.AccessToken, "accessToken") + + ce, err := azureLogin.GetCloudEnvironment() + assert.NilError(t, err) + assert.Equal(t, ce.Name, AzurePublicCloudName) } func TestInvalidLogin(t *testing.T) { m := &MockAzureHelper{} - m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { + m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) { redirectURL := args.Get(0).(string) err := queryKeyValue(redirectURL, "error", "access denied: login failed") assert.NilError(t, err) }).Return(nil) - azureLogin, err := testLoginService(t, m) + azureLogin, err := testLoginService(t, m, nil) assert.NilError(t, err) - err = azureLogin.Login(context.TODO(), "") + err = azureLogin.Login(context.TODO(), "", AzurePublicCloudName) assert.Error(t, err, "no login code: login failed") } @@ -129,19 +182,22 @@ func TestValidLogin(t *testing.T) { var redirectURL string ctx := context.TODO() m := &MockAzureHelper{} - m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { + ce, err := CloudEnvironments.Get(AzurePublicCloudName) + assert.NilError(t, err) + + m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) { redirectURL = args.Get(0).(string) err := queryKeyValue(redirectURL, "code", "123456879") assert.NilError(t, err) }).Return(nil) - m.On("queryToken", mock.MatchedBy(func(data url.Values) bool { + m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool { //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage return reflect.DeepEqual(data, url.Values{ "grant_type": []string{"authorization_code"}, "client_id": []string{clientID}, "code": []string{"123456879"}, - "scope": []string{scopes}, + "scope": []string{ce.GetTokenScope()}, "redirect_uri": []string{redirectURL}, }) }), "organizations").Return(azureToken{ @@ -153,18 +209,18 @@ func TestValidLogin(t *testing.T) { authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` - m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) - data := refreshTokenData("firstRefreshToken") - m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ + m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + data := refreshTokenData("firstRefreshToken", ce) + m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ RefreshToken: "newRefreshToken", AccessToken: "newAccessToken", ExpiresIn: 3600, Foci: "1", }, nil) - azureLogin, err := testLoginService(t, m) + azureLogin, err := testLoginService(t, m, nil) assert.NilError(t, err) - err = azureLogin.Login(ctx, "") + err = azureLogin.Login(ctx, "", AzurePublicCloudName) assert.NilError(t, err) loginToken, err := azureLogin.tokenStore.readToken() @@ -174,24 +230,28 @@ func TestValidLogin(t *testing.T) { assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry)) assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038") assert.Equal(t, loginToken.Token.Type(), "Bearer") + assert.Equal(t, loginToken.CloudEnvironment, "AzureCloud") } func TestValidLoginRequestedTenant(t *testing.T) { var redirectURL string m := &MockAzureHelper{} - m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { + ce, err := CloudEnvironments.Get(AzurePublicCloudName) + assert.NilError(t, err) + + m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) { redirectURL = args.Get(0).(string) err := queryKeyValue(redirectURL, "code", "123456879") assert.NilError(t, err) }).Return(nil) - m.On("queryToken", mock.MatchedBy(func(data url.Values) bool { + m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool { //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage return reflect.DeepEqual(data, url.Values{ "grant_type": []string{"authorization_code"}, "client_id": []string{clientID}, "code": []string{"123456879"}, - "scope": []string{scopes}, + "scope": []string{ce.GetTokenScope()}, "redirect_uri": []string{redirectURL}, }) }), "organizations").Return(azureToken{ @@ -205,18 +265,18 @@ func TestValidLoginRequestedTenant(t *testing.T) { {"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` ctx := context.TODO() - m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) - data := refreshTokenData("firstRefreshToken") - m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ + m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + data := refreshTokenData("firstRefreshToken", ce) + m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ RefreshToken: "newRefreshToken", AccessToken: "newAccessToken", ExpiresIn: 3600, Foci: "1", }, nil) - azureLogin, err := testLoginService(t, m) + azureLogin, err := testLoginService(t, m, nil) assert.NilError(t, err) - err = azureLogin.Login(ctx, "12345a7c-c56d-43e8-9549-dd230ce8a038") + err = azureLogin.Login(ctx, "12345a7c-c56d-43e8-9549-dd230ce8a038", AzurePublicCloudName) assert.NilError(t, err) loginToken, err := azureLogin.tokenStore.readToken() @@ -226,24 +286,28 @@ func TestValidLoginRequestedTenant(t *testing.T) { assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry)) assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038") assert.Equal(t, loginToken.Token.Type(), "Bearer") + assert.Equal(t, loginToken.CloudEnvironment, "AzureCloud") } func TestLoginNoTenant(t *testing.T) { var redirectURL string m := &MockAzureHelper{} - m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { + ce, err := CloudEnvironments.Get(AzurePublicCloudName) + assert.NilError(t, err) + + m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) { redirectURL = args.Get(0).(string) err := queryKeyValue(redirectURL, "code", "123456879") assert.NilError(t, err) }).Return(nil) - m.On("queryToken", mock.MatchedBy(func(data url.Values) bool { + m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool { //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage return reflect.DeepEqual(data, url.Values{ "grant_type": []string{"authorization_code"}, "client_id": []string{clientID}, "code": []string{"123456879"}, - "scope": []string{scopes}, + "scope": []string{ce.GetTokenScope()}, "redirect_uri": []string{redirectURL}, }) }), "organizations").Return(azureToken{ @@ -255,31 +319,34 @@ func TestLoginNoTenant(t *testing.T) { ctx := context.TODO() authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` - m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) - azureLogin, err := testLoginService(t, m) + azureLogin, err := testLoginService(t, m, nil) assert.NilError(t, err) - err = azureLogin.Login(ctx, "00000000-c56d-43e8-9549-dd230ce8a038") + err = azureLogin.Login(ctx, "00000000-c56d-43e8-9549-dd230ce8a038", AzurePublicCloudName) assert.Error(t, err, "could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed") } func TestLoginRequestedTenantNotFound(t *testing.T) { var redirectURL string m := &MockAzureHelper{} - m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { + ce, err := CloudEnvironments.Get(AzurePublicCloudName) + assert.NilError(t, err) + + m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) { redirectURL = args.Get(0).(string) err := queryKeyValue(redirectURL, "code", "123456879") assert.NilError(t, err) }).Return(nil) - m.On("queryToken", mock.MatchedBy(func(data url.Values) bool { + m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool { //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage return reflect.DeepEqual(data, url.Values{ "grant_type": []string{"authorization_code"}, "client_id": []string{clientID}, "code": []string{"123456879"}, - "scope": []string{scopes}, + "scope": []string{ce.GetTokenScope()}, "redirect_uri": []string{redirectURL}, }) }), "organizations").Return(azureToken{ @@ -291,31 +358,34 @@ func TestLoginRequestedTenantNotFound(t *testing.T) { ctx := context.TODO() authBody := `{"value":[]}` - m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) - azureLogin, err := testLoginService(t, m) + azureLogin, err := testLoginService(t, m, nil) assert.NilError(t, err) - err = azureLogin.Login(ctx, "") + err = azureLogin.Login(ctx, "", AzurePublicCloudName) assert.Error(t, err, "could not find azure tenant: login failed") } func TestLoginAuthorizationFailed(t *testing.T) { var redirectURL string m := &MockAzureHelper{} - m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { + ce, err := CloudEnvironments.Get(AzurePublicCloudName) + assert.NilError(t, err) + + m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) { redirectURL = args.Get(0).(string) err := queryKeyValue(redirectURL, "code", "123456879") assert.NilError(t, err) }).Return(nil) - m.On("queryToken", mock.MatchedBy(func(data url.Values) bool { + m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool { //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage return reflect.DeepEqual(data, url.Values{ "grant_type": []string{"authorization_code"}, "client_id": []string{clientID}, "code": []string{"123456879"}, - "scope": []string{scopes}, + "scope": []string{ce.GetTokenScope()}, "redirect_uri": []string{redirectURL}, }) }), "organizations").Return(azureToken{ @@ -328,35 +398,38 @@ func TestLoginAuthorizationFailed(t *testing.T) { authBody := `[access denied]` ctx := context.TODO() - m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil) + m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 400, nil) - azureLogin, err := testLoginService(t, m) + azureLogin, err := testLoginService(t, m, nil) assert.NilError(t, err) - err = azureLogin.Login(ctx, "") + err = azureLogin.Login(ctx, "", AzurePublicCloudName) assert.Error(t, err, "unable to login status code 400: [access denied]: login failed") } func TestValidThroughDeviceCodeFlow(t *testing.T) { m := &MockAzureHelper{} - m.On("openAzureLoginPage", mock.AnythingOfType("string")).Return(errors.New("Could not open browser")) - m.On("getDeviceCodeFlowToken").Return(adal.Token{AccessToken: "firstAccessToken", RefreshToken: "firstRefreshToken"}, nil) + ce, err := CloudEnvironments.Get(AzurePublicCloudName) + assert.NilError(t, err) + + m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Return(errors.New("Could not open browser")) + m.On("getDeviceCodeFlowToken", mock.AnythingOfType("CloudEnvironment")).Return(adal.Token{AccessToken: "firstAccessToken", RefreshToken: "firstRefreshToken"}, nil) authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` ctx := context.TODO() - m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) - data := refreshTokenData("firstRefreshToken") - m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ + m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + data := refreshTokenData("firstRefreshToken", ce) + m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ RefreshToken: "newRefreshToken", AccessToken: "newAccessToken", ExpiresIn: 3600, Foci: "1", }, nil) - azureLogin, err := testLoginService(t, m) + azureLogin, err := testLoginService(t, m, nil) assert.NilError(t, err) - err = azureLogin.Login(ctx, "") + err = azureLogin.Login(ctx, "", AzurePublicCloudName) assert.NilError(t, err) loginToken, err := azureLogin.tokenStore.readToken() @@ -366,13 +439,110 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) { assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry)) assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038") assert.Equal(t, loginToken.Token.Type(), "Bearer") + assert.Equal(t, loginToken.CloudEnvironment, "AzureCloud") } -func refreshTokenData(refreshToken string) url.Values { +func TestNonstandardCloudEnvironment(t *testing.T) { + dockerCloudMetadata := []byte(` + [{ + "authentication": { + "loginEndpoint": "https://login.docker.com/", + "audiences": [ + "https://management.docker.com/", + "https://management.cli.docker.com/" + ], + "tenant": "F5773994-FE88-482E-9E33-6E799D250416" + }, + "name": "AzureDockerCloud", + "suffixes": { + "acrLoginServer": "azurecr.docker.io" + }, + "resourceManager": "https://management.docker.com/" + }]`) + var metadataReqCount int32 = 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write(dockerCloudMetadata) + assert.NilError(t, err) + atomic.AddInt32(&metadataReqCount, 1) + })) + defer srv.Close() + + cloudMetadataURL, cloudMetadataURLSet := os.LookupEnv(CloudMetadataURLVar) + if cloudMetadataURLSet { + defer func() { + err := os.Setenv(CloudMetadataURLVar, cloudMetadataURL) + assert.NilError(t, err) + }() + } + err := os.Setenv(CloudMetadataURLVar, srv.URL) + assert.NilError(t, err) + + ctx := context.TODO() + + ces := newCloudEnvironmentService() + ces.cloudMetadataURL = srv.URL + dockerCloudEnv, err := ces.Get("AzureDockerCloud") + assert.NilError(t, err) + + helperMock := &MockAzureHelper{} + var redirectURL string + helperMock.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) { + redirectURL = args.Get(0).(string) + err := queryKeyValue(redirectURL, "code", "123456879") + assert.NilError(t, err) + }).Return(nil) + + helperMock.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool { + //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage + return reflect.DeepEqual(data, url.Values{ + "grant_type": []string{"authorization_code"}, + "client_id": []string{clientID}, + "code": []string{"123456879"}, + "scope": []string{dockerCloudEnv.GetTokenScope()}, + "redirect_uri": []string{redirectURL}, + }) + }), "organizations").Return(azureToken{ + RefreshToken: "firstRefreshToken", + AccessToken: "firstAccessToken", + ExpiresIn: 3600, + Foci: "1", + }, nil) + + authBody := `{"value":[{"id":"/tenants/F5773994-FE88-482E-9E33-6E799D250416","tenantId":"F5773994-FE88-482E-9E33-6E799D250416"}]}` + + helperMock.On("queryAPIWithHeader", ctx, dockerCloudEnv.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + data := refreshTokenData("firstRefreshToken", dockerCloudEnv) + helperMock.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "F5773994-FE88-482E-9E33-6E799D250416").Return(azureToken{ + RefreshToken: "newRefreshToken", + AccessToken: "newAccessToken", + ExpiresIn: 3600, + Foci: "1", + }, nil) + + azureLogin, err := testLoginService(t, helperMock, ces) + assert.NilError(t, err) + + err = azureLogin.Login(ctx, "", "AzureDockerCloud") + assert.NilError(t, err) + + loginToken, err := azureLogin.tokenStore.readToken() + assert.NilError(t, err) + assert.Equal(t, loginToken.Token.AccessToken, "newAccessToken") + assert.Equal(t, loginToken.Token.RefreshToken, "newRefreshToken") + assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry)) + assert.Equal(t, loginToken.TenantID, "F5773994-FE88-482E-9E33-6E799D250416") + assert.Equal(t, loginToken.Token.Type(), "Bearer") + assert.Equal(t, loginToken.CloudEnvironment, "AzureDockerCloud") + assert.Equal(t, metadataReqCount, int32(1)) +} + +// Don't warn about refreshToken parameter taking the same value for all invocations +// nolint:unparam +func refreshTokenData(refreshToken string, ce CloudEnvironment) url.Values { return url.Values{ "grant_type": []string{"refresh_token"}, "client_id": []string{clientID}, - "scope": []string{scopes}, + "scope": []string{ce.GetTokenScope()}, "refresh_token": []string{refreshToken}, } } @@ -394,13 +564,13 @@ type MockAzureHelper struct { mock.Mock } -func (s *MockAzureHelper) getDeviceCodeFlowToken() (adal.Token, error) { - args := s.Called() +func (s *MockAzureHelper) getDeviceCodeFlowToken(ce CloudEnvironment) (adal.Token, error) { + args := s.Called(ce) return args.Get(0).(adal.Token), args.Error(1) } -func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token azureToken, err error) { - args := s.Called(data, tenantID) +func (s *MockAzureHelper) queryToken(ce CloudEnvironment, data url.Values, tenantID string) (token azureToken, err error) { + args := s.Called(ce, data, tenantID) return args.Get(0).(azureToken), args.Error(1) } @@ -409,7 +579,16 @@ func (s *MockAzureHelper) queryAPIWithHeader(ctx context.Context, authorizationU return args.Get(0).([]byte), args.Int(1), args.Error(2) } -func (s *MockAzureHelper) openAzureLoginPage(redirectURL string) error { - args := s.Called(redirectURL) +func (s *MockAzureHelper) openAzureLoginPage(redirectURL string, ce CloudEnvironment) error { + args := s.Called(redirectURL, ce) return args.Error(0) } + +type MockCloudEnvironmentService struct { + mock.Mock +} + +func (s *MockCloudEnvironmentService) Get(name string) (CloudEnvironment, error) { + args := s.Called(name) + return args.Get(0).(CloudEnvironment), args.Error(1) +} diff --git a/aci/login/token_store.go b/aci/login/token_store.go index 337e9b931..131006888 100644 --- a/aci/login/token_store.go +++ b/aci/login/token_store.go @@ -34,8 +34,9 @@ type tokenStore struct { // TokenInfo data stored in tokenStore type TokenInfo struct { - Token oauth2.Token `json:"oauthToken"` - TenantID string `json:"tenantId"` + Token oauth2.Token `json:"oauthToken"` + TenantID string `json:"tenantId"` + CloudEnvironment string `json:"cloudEnvironment"` } func newTokenStore(path string) (tokenStore, error) { @@ -82,6 +83,9 @@ func (store tokenStore) readToken() (TokenInfo, error) { if err := json.Unmarshal(bytes, &loginInfo); err != nil { return TokenInfo{}, err } + if loginInfo.CloudEnvironment == "" { + loginInfo.CloudEnvironment = AzurePublicCloudName + } return loginInfo, nil } diff --git a/cli/cmd/login/azurelogin.go b/cli/cmd/login/azurelogin.go index fd90170c5..ef6d687c5 100644 --- a/cli/cmd/login/azurelogin.go +++ b/cli/cmd/login/azurelogin.go @@ -40,6 +40,7 @@ func AzureLoginCommand() *cobra.Command { flags.StringVar(&opts.TenantID, "tenant-id", "", "Specify tenant ID to use") flags.StringVar(&opts.ClientID, "client-id", "", "Client ID for Service principal login") flags.StringVar(&opts.ClientSecret, "client-secret", "", "Client secret for Service principal login") + flags.StringVar(&opts.CloudName, "cloud-name", "", "Name of a registered Azure cloud") return cmd }