Add Azure sovereign cloud support

Signed-off-by: Karol Zadora-Przylecki <karolz@microsoft.com>
This commit is contained in:
Karol Zadora-Przylecki 2021-01-31 12:15:00 -07:00
parent 9063c138ba
commit cc649d958c
15 changed files with 973 additions and 236 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
bin/
dist/
/.vscode/

View File

@ -51,6 +51,7 @@ type LoginParams struct {
TenantID string
ClientID string
ClientSecret string
CloudName string
}
// Validate returns an error if options are not used properly

View File

@ -23,7 +23,9 @@ import (
"github.com/stretchr/testify/mock"
"gotest.tools/v3/assert"
"github.com/docker/compose-cli/aci/login"
"github.com/docker/compose-cli/api/containers"
"golang.org/x/oauth2"
)
func TestGetContainerName(t *testing.T) {
@ -82,7 +84,7 @@ func TestLoginParamsValidate(t *testing.T) {
func TestLoginServicePrincipal(t *testing.T) {
loginService := mockLoginService{}
loginService.On("LoginServicePrincipal", "someID", "secret", "tenant").Return(nil)
loginService.On("LoginServicePrincipal", "someID", "secret", "tenant", "someCloud").Return(nil)
loginBackend := aciCloudService{
loginService: &loginService,
}
@ -91,6 +93,7 @@ func TestLoginServicePrincipal(t *testing.T) {
ClientID: "someID",
ClientSecret: "secret",
TenantID: "tenant",
CloudName: "someCloud",
})
assert.NilError(t, err)
@ -99,13 +102,14 @@ func TestLoginServicePrincipal(t *testing.T) {
func TestLoginWithTenant(t *testing.T) {
loginService := mockLoginService{}
ctx := context.Background()
loginService.On("Login", ctx, "tenant").Return(nil)
loginService.On("Login", ctx, "tenant", "someCloud").Return(nil)
loginBackend := aciCloudService{
loginService: &loginService,
}
err := loginBackend.Login(ctx, LoginParams{
TenantID: "tenant",
TenantID: "tenant",
CloudName: "someCloud",
})
assert.NilError(t, err)
@ -114,12 +118,14 @@ func TestLoginWithTenant(t *testing.T) {
func TestLoginWithoutTenant(t *testing.T) {
loginService := mockLoginService{}
ctx := context.Background()
loginService.On("Login", ctx, "").Return(nil)
loginService.On("Login", ctx, "", "someCloud").Return(nil)
loginBackend := aciCloudService{
loginService: &loginService,
}
err := loginBackend.Login(ctx, LoginParams{})
err := loginBackend.Login(ctx, LoginParams{
CloudName: "someCloud",
})
assert.NilError(t, err)
}
@ -128,13 +134,13 @@ type mockLoginService struct {
mock.Mock
}
func (s *mockLoginService) Login(ctx context.Context, requestedTenantID string) error {
args := s.Called(ctx, requestedTenantID)
func (s *mockLoginService) Login(ctx context.Context, requestedTenantID string, cloudEnvironment string) error {
args := s.Called(ctx, requestedTenantID, cloudEnvironment)
return args.Error(0)
}
func (s *mockLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error {
args := s.Called(clientID, clientSecret, tenantID)
func (s *mockLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string, cloudEnvironment string) error {
args := s.Called(clientID, clientSecret, tenantID, cloudEnvironment)
return args.Error(0)
}
@ -142,3 +148,18 @@ func (s *mockLoginService) Logout(ctx context.Context) error {
args := s.Called(ctx)
return args.Error(0)
}
func (s *mockLoginService) GetTenantID() (string, error) {
args := s.Called()
return args.String(0), args.Error(1)
}
func (s *mockLoginService) GetCloudEnvironment() (login.CloudEnvironment, error) {
args := s.Called()
return args.Get(0).(login.CloudEnvironment), args.Error(1)
}
func (s *mockLoginService) GetValidToken() (oauth2.Token, string, error) {
args := s.Called()
return args.Get(0).(oauth2.Token), args.String(1), args.Error(2)
}

View File

@ -25,7 +25,7 @@ import (
)
type aciCloudService struct {
loginService login.AzureLoginServiceAPI
loginService login.AzureLoginService
}
func (cs *aciCloudService) Login(ctx context.Context, params interface{}) error {
@ -33,10 +33,13 @@ func (cs *aciCloudService) Login(ctx context.Context, params interface{}) error
if !ok {
return errors.New("could not read Azure LoginParams struct from generic parameter")
}
if opts.ClientID != "" {
return cs.loginService.LoginServicePrincipal(opts.ClientID, opts.ClientSecret, opts.TenantID)
if opts.CloudName == "" {
opts.CloudName = login.AzurePublicCloudName
}
return cs.loginService.Login(ctx, opts.TenantID)
if opts.ClientID != "" {
return cs.loginService.LoginServicePrincipal(opts.ClientID, opts.ClientSecret, opts.TenantID, opts.CloudName)
}
return cs.loginService.Login(ctx, opts.TenantID, opts.CloudName)
}
func (cs *aciCloudService) Logout(ctx context.Context) error {

View File

@ -47,7 +47,7 @@ const (
type registryHelper interface {
getAllRegistryCredentials() (map[string]types.AuthConfig, error)
autoLoginAcr(registry string) error
autoLoginAcr(registry string, loginService login.AzureLoginService) error
}
type cliRegistryHelper struct {
@ -65,9 +65,19 @@ func newCliRegistryConfLoader() cliRegistryHelper {
}
func getRegistryCredentials(project compose.Project, helper registryHelper) ([]containerinstance.ImageRegistryCredential, error) {
usedRegistries, acrRegistries := getUsedRegistries(project)
loginService, err := login.NewAzureLoginService()
if err != nil {
return nil, err
}
var cloudEnvironment *login.CloudEnvironment = nil
if ce, err := loginService.GetCloudEnvironment(); err != nil {
cloudEnvironment = &ce
}
usedRegistries, acrRegistries := getUsedRegistries(project, cloudEnvironment)
for _, registry := range acrRegistries {
err := helper.autoLoginAcr(registry)
err := helper.autoLoginAcr(registry, loginService)
if err != nil {
fmt.Printf("WARNING: %v\n", err)
fmt.Printf("Could not automatically login to %s from your Azure login. Assuming you already logged in to the ACR registry\n", registry)
@ -116,9 +126,10 @@ func getRegistryCredentials(project compose.Project, helper registryHelper) ([]c
return registryCreds, nil
}
func getUsedRegistries(project compose.Project) (map[string]bool, []string) {
func getUsedRegistries(project compose.Project, ce *login.CloudEnvironment) (map[string]bool, []string) {
usedRegistries := map[string]bool{}
acrRegistries := []string{}
for _, service := range project.Services {
imageName := service.Image
tokens := strings.Split(imageName, "/")
@ -127,24 +138,18 @@ func getUsedRegistries(project compose.Project) (map[string]bool, []string) {
registry = dockerHub
} else if !strings.Contains(registry, ".") {
registry = dockerHub
} else if strings.HasSuffix(registry, login.AcrRegistrySuffix) {
acrRegistries = append(acrRegistries, registry)
} else if ce != nil {
if suffix, present := ce.Suffixes[login.AcrSuffixKey]; present && strings.HasSuffix(registry, suffix) {
acrRegistries = append(acrRegistries, registry)
}
}
usedRegistries[registry] = true
}
return usedRegistries, acrRegistries
}
func (c cliRegistryHelper) autoLoginAcr(registry string) error {
loginService, err := login.NewAzureLoginService()
if err != nil {
return err
}
token, err := loginService.GetValidToken()
if err != nil {
return err
}
tenantID, err := loginService.GetTenantID()
func (c cliRegistryHelper) autoLoginAcr(registry string, loginService login.AzureLoginService) error {
token, tenantID, err := loginService.GetValidToken()
if err != nil {
return err
}

View File

@ -25,6 +25,7 @@ import (
"github.com/Azure/go-autorest/autorest/to"
"github.com/compose-spec/compose-go/types"
cliconfigtypes "github.com/docker/cli/cli/config/types"
"github.com/docker/compose-cli/aci/login"
"github.com/stretchr/testify/mock"
"gotest.tools/v3/assert"
is "gotest.tools/v3/assert/cmp"
@ -255,7 +256,7 @@ func (s *MockRegistryHelper) getAllRegistryCredentials() (map[string]cliconfigty
return args.Get(0).(map[string]cliconfigtypes.AuthConfig), args.Error(1)
}
func (s *MockRegistryHelper) autoLoginAcr(registry string) error {
args := s.Called(registry)
func (s *MockRegistryHelper) autoLoginAcr(registry string, loginService login.AzureLoginService) error {
args := s.Called(registry, loginService)
return args.Error(0)
}

View File

@ -17,6 +17,8 @@
package login
import (
"encoding/json"
"strconv"
"time"
"github.com/Azure/azure-sdk-for-go/profiles/2019-03-01/resources/mgmt/resources"
@ -24,6 +26,8 @@ import (
"github.com/Azure/azure-sdk-for-go/services/containerinstance/mgmt/2019-12-01/containerinstance"
"github.com/Azure/azure-sdk-for-go/services/storage/mgmt/2019-06-01/storage"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/date"
"github.com/pkg/errors"
"github.com/docker/compose-cli/api/errdefs"
@ -32,8 +36,12 @@ import (
// NewContainerGroupsClient get client toi manipulate containerGrouos
func NewContainerGroupsClient(subscriptionID string) (containerinstance.ContainerGroupsClient, error) {
containerGroupsClient := containerinstance.NewContainerGroupsClient(subscriptionID)
err := setupClient(&containerGroupsClient.Client)
authorizer, mgmtURL, err := getClientSetupData()
if err != nil {
return containerinstance.ContainerGroupsClient{}, err
}
containerGroupsClient := containerinstance.NewContainerGroupsClientWithBaseURI(mgmtURL, subscriptionID)
setupClient(&containerGroupsClient.Client, authorizer)
if err != nil {
return containerinstance.ContainerGroupsClient{}, err
}
@ -43,68 +51,100 @@ func NewContainerGroupsClient(subscriptionID string) (containerinstance.Containe
return containerGroupsClient, nil
}
func setupClient(aciClient *autorest.Client) error {
func setupClient(aciClient *autorest.Client, auth autorest.Authorizer) {
aciClient.UserAgent = internal.UserAgentName + "/" + internal.Version
auth, err := NewAuthorizerFromLogin()
if err != nil {
return err
}
aciClient.Authorizer = auth
return nil
}
// NewStorageAccountsClient get client to manipulate storage accounts
func NewStorageAccountsClient(subscriptionID string) (storage.AccountsClient, error) {
containerGroupsClient := storage.NewAccountsClient(subscriptionID)
err := setupClient(&containerGroupsClient.Client)
authorizer, mgmtURL, err := getClientSetupData()
if err != nil {
return storage.AccountsClient{}, err
}
containerGroupsClient.PollingDelay = 5 * time.Second
containerGroupsClient.RetryAttempts = 30
containerGroupsClient.RetryDuration = 1 * time.Second
return containerGroupsClient, nil
storageAccuntsClient := storage.NewAccountsClientWithBaseURI(mgmtURL, subscriptionID)
setupClient(&storageAccuntsClient.Client, authorizer)
storageAccuntsClient.PollingDelay = 5 * time.Second
storageAccuntsClient.RetryAttempts = 30
storageAccuntsClient.RetryDuration = 1 * time.Second
return storageAccuntsClient, nil
}
// NewFileShareClient get client to manipulate file shares
func NewFileShareClient(subscriptionID string) (storage.FileSharesClient, error) {
containerGroupsClient := storage.NewFileSharesClient(subscriptionID)
err := setupClient(&containerGroupsClient.Client)
authorizer, mgmtURL, err := getClientSetupData()
if err != nil {
return storage.FileSharesClient{}, err
}
containerGroupsClient.PollingDelay = 5 * time.Second
containerGroupsClient.RetryAttempts = 30
containerGroupsClient.RetryDuration = 1 * time.Second
return containerGroupsClient, nil
fileSharesClient := storage.NewFileSharesClientWithBaseURI(mgmtURL, subscriptionID)
setupClient(&fileSharesClient.Client, authorizer)
fileSharesClient.PollingDelay = 5 * time.Second
fileSharesClient.RetryAttempts = 30
fileSharesClient.RetryDuration = 1 * time.Second
return fileSharesClient, nil
}
// NewSubscriptionsClient get subscription client
func NewSubscriptionsClient() (subscription.SubscriptionsClient, error) {
subc := subscription.NewSubscriptionsClient()
err := setupClient(&subc.Client)
authorizer, mgmtURL, err := getClientSetupData()
if err != nil {
return subscription.SubscriptionsClient{}, errors.Wrap(errdefs.ErrLoginRequired, err.Error())
}
subc := subscription.NewSubscriptionsClientWithBaseURI(mgmtURL)
setupClient(&subc.Client, authorizer)
return subc, nil
}
// NewGroupsClient get client to manipulate groups
func NewGroupsClient(subscriptionID string) (resources.GroupsClient, error) {
groupsClient := resources.NewGroupsClient(subscriptionID)
err := setupClient(&groupsClient.Client)
authorizer, mgmtURL, err := getClientSetupData()
if err != nil {
return resources.GroupsClient{}, err
}
groupsClient := resources.NewGroupsClientWithBaseURI(mgmtURL, subscriptionID)
setupClient(&groupsClient.Client, authorizer)
return groupsClient, nil
}
// NewContainerClient get client to manipulate containers
func NewContainerClient(subscriptionID string) (containerinstance.ContainersClient, error) {
containerClient := containerinstance.NewContainersClient(subscriptionID)
err := setupClient(&containerClient.Client)
authorizer, mgmtURL, err := getClientSetupData()
if err != nil {
return containerinstance.ContainersClient{}, err
}
containerClient := containerinstance.NewContainersClientWithBaseURI(mgmtURL, subscriptionID)
setupClient(&containerClient.Client, authorizer)
return containerClient, nil
}
func getClientSetupData() (autorest.Authorizer, string, error) {
return getClientSetupDataImpl(GetTokenStorePath())
}
func getClientSetupDataImpl(tokenStorePath string) (autorest.Authorizer, string, error) {
als, err := newAzureLoginServiceFromPath(tokenStorePath, azureAPIHelper{}, CloudEnvironments)
if err != nil {
return nil, "", err
}
oauthToken, _, err := als.GetValidToken()
if err != nil {
return nil, "", errors.Wrap(err, "not logged in to azure, you need to run \"docker login azure\" first")
}
ce, err := als.GetCloudEnvironment()
if err != nil {
return nil, "", err
}
token := adal.Token{
AccessToken: oauthToken.AccessToken,
Type: oauthToken.TokenType,
ExpiresIn: json.Number(strconv.Itoa(int(time.Until(oauthToken.Expiry).Seconds()))),
ExpiresOn: json.Number(strconv.Itoa(int(oauthToken.Expiry.Sub(date.UnixEpoch()).Seconds()))),
RefreshToken: "",
Resource: "",
}
return autorest.NewBearerAuthorizer(&token), ce.ResourceManagerURL, nil
}

36
aci/login/client_test.go Normal file
View File

@ -0,0 +1,36 @@
/*
Copyright 2020 Docker Compose CLI authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package login
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
"gotest.tools/v3/assert"
)
func TestClearErrorMessageIfNotAlreadyLoggedIn(t *testing.T) {
dir, err := ioutil.TempDir("", "test_store")
assert.NilError(t, err)
t.Cleanup(func() {
_ = os.RemoveAll(dir)
})
_, _, err = getClientSetupDataImpl(filepath.Join(dir, tokenStoreFilename))
assert.ErrorContains(t, err, "not logged in to azure, you need to run \"docker login azure\" first")
}

View File

@ -0,0 +1,274 @@
/*
Copyright 2020 Docker Compose CLI authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package login
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"strings"
"github.com/pkg/errors"
)
const (
// AzurePublicCloudName is the moniker of the Azure public cloud
AzurePublicCloudName = "AzureCloud"
// AcrSuffixKey is the well-known name of the DNS suffix for Azure Container Registries
AcrSuffixKey = "acrLoginServer"
// CloudMetadataURLVar is the name of the environment variable that (if defined), points to a URL that should be used by Docker CLI to retrieve cloud metadata
CloudMetadataURLVar = "ARM_CLOUD_METADATA_URL"
// DefaultCloudMetadataURL is the URL of the cloud metadata service maintained by Azure public cloud
DefaultCloudMetadataURL = "https://management.azure.com/metadata/endpoints?api-version=2019-05-01"
)
// CloudEnvironmentService exposed metadata about Azure cloud environments
type CloudEnvironmentService interface {
Get(name string) (CloudEnvironment, error)
}
type cloudEnvironmentService struct {
cloudEnvironments map[string]CloudEnvironment
cloudMetadataURL string
// True if we have queried the cloud metadata endpoint already.
// We do it only once per CLI invocation.
metadataQueried bool
}
var (
// CloudEnvironments is the default instance of the CloudEnvironmentService
CloudEnvironments CloudEnvironmentService
)
func init() {
CloudEnvironments = newCloudEnvironmentService()
}
// CloudEnvironmentAuthentication data for logging into, and obtaining tokens for, Azure sovereign clouds
type CloudEnvironmentAuthentication struct {
LoginEndpoint string `json:"loginEndpoint"`
Audiences []string `json:"audiences"`
Tenant string `json:"tenant"`
}
// CloudEnvironment describes Azure sovereign cloud instance (e.g. Azure public, Azure US government, Azure China etc.)
type CloudEnvironment struct {
Name string `json:"name"`
Authentication CloudEnvironmentAuthentication `json:"authentication"`
ResourceManagerURL string `json:"resourceManager"`
Suffixes map[string]string `json:"suffixes"`
}
func newCloudEnvironmentService() *cloudEnvironmentService {
retval := cloudEnvironmentService{
metadataQueried: false,
}
retval.resetCloudMetadata()
return &retval
}
func (ces *cloudEnvironmentService) Get(name string) (CloudEnvironment, error) {
if ce, present := ces.cloudEnvironments[name]; present {
return ce, nil
}
if !ces.metadataQueried {
ces.metadataQueried = true
if ces.cloudMetadataURL == "" {
ces.cloudMetadataURL = os.Getenv(CloudMetadataURLVar)
if _, err := url.ParseRequestURI(ces.cloudMetadataURL); err != nil {
ces.cloudMetadataURL = DefaultCloudMetadataURL
}
}
res, err := http.Get(ces.cloudMetadataURL)
if err != nil {
return CloudEnvironment{}, fmt.Errorf("Cloud metadata retrieval from '%s' failed: %w", ces.cloudMetadataURL, err)
}
if res.StatusCode != 200 {
return CloudEnvironment{}, errors.Errorf("Cloud metadata retrieval from '%s' failed: server response was '%s'", ces.cloudMetadataURL, res.Status)
}
bytes, err := ioutil.ReadAll(res.Body)
if err != nil {
return CloudEnvironment{}, fmt.Errorf("Cloud metadata retrieval from '%s' failed: %w", ces.cloudMetadataURL, err)
}
if err = ces.applyCloudMetadata(bytes); err != nil {
return CloudEnvironment{}, fmt.Errorf("Cloud metadata retrieval from '%s' failed: %w", ces.cloudMetadataURL, err)
}
}
if ce, present := ces.cloudEnvironments[name]; present {
return ce, nil
}
return CloudEnvironment{}, errors.Errorf("Cloud environment '%s' was not found", name)
}
func (ces *cloudEnvironmentService) applyCloudMetadata(jsonBytes []byte) error {
input := []CloudEnvironment{}
if err := json.Unmarshal(jsonBytes, &input); err != nil {
return err
}
newEnvironments := make(map[string]CloudEnvironment, len(input))
// If _any_ of the submitted data is invalid, we bail out.
for _, e := range input {
if len(e.Name) == 0 {
return errors.New("Azure cloud environment metadata is invalid (an environment with no name has been encountered)")
}
e.normalizeURLs()
if _, err := url.ParseRequestURI(e.Authentication.LoginEndpoint); err != nil {
return errors.Errorf("Metadata of cloud environment '%s' has invalid login endpoint URL: %s", e.Name, e.Authentication.LoginEndpoint)
}
if _, err := url.ParseRequestURI(e.ResourceManagerURL); err != nil {
return errors.Errorf("Metadata of cloud environment '%s' has invalid resource manager URL: %s", e.Name, e.ResourceManagerURL)
}
if len(e.Authentication.Audiences) == 0 {
return errors.Errorf("Metadata of cloud environment '%s' is invalid (no authentication audiences)", e.Name)
}
newEnvironments[e.Name] = e
}
for name, e := range newEnvironments {
ces.cloudEnvironments[name] = e
}
return nil
}
func (ces *cloudEnvironmentService) resetCloudMetadata() {
azurePublicCloud := CloudEnvironment{
Name: AzurePublicCloudName,
Authentication: CloudEnvironmentAuthentication{
LoginEndpoint: "https://login.microsoftonline.com",
Audiences: []string{
"https://management.core.windows.net",
"https://management.azure.com",
},
Tenant: "common",
},
ResourceManagerURL: "https://management.azure.com",
Suffixes: map[string]string{
AcrSuffixKey: "azurecr.io",
},
}
azureChinaCloud := CloudEnvironment{
Name: "AzureChinaCloud",
Authentication: CloudEnvironmentAuthentication{
LoginEndpoint: "https://login.chinacloudapi.cn",
Audiences: []string{
"https://management.core.chinacloudapi.cn",
"https://management.chinacloudapi.cn",
},
Tenant: "common",
},
ResourceManagerURL: "https://management.chinacloudapi.cn",
Suffixes: map[string]string{
AcrSuffixKey: "azurecr.cn",
},
}
azureUSGovernment := CloudEnvironment{
Name: "AzureUSGovernment",
Authentication: CloudEnvironmentAuthentication{
LoginEndpoint: "https://login.microsoftonline.us",
Audiences: []string{
"https://management.core.usgovcloudapi.net",
"https://management.usgovcloudapi.net",
},
Tenant: "common",
},
ResourceManagerURL: "https://management.usgovcloudapi.net",
Suffixes: map[string]string{
AcrSuffixKey: "azurecr.us",
},
}
azureGermanCloud := CloudEnvironment{
Name: "AzureGermanCloud",
Authentication: CloudEnvironmentAuthentication{
LoginEndpoint: "https://login.microsoftonline.de",
Audiences: []string{
"https://management.core.cloudapi.de",
"https://management.microsoftazure.de",
},
Tenant: "common",
},
ResourceManagerURL: "https://management.microsoftazure.de",
// There is no separate container registry suffix for German cloud
Suffixes: map[string]string{},
}
ces.cloudEnvironments = map[string]CloudEnvironment{
azurePublicCloud.Name: azurePublicCloud,
azureChinaCloud.Name: azureChinaCloud,
azureUSGovernment.Name: azureUSGovernment,
azureGermanCloud.Name: azureGermanCloud,
}
}
// GetTenantQueryURL returns an URL that can be used to fetch the list of Azure Active Directory tenants from a given cloud environment
func (ce *CloudEnvironment) GetTenantQueryURL() string {
tenantURL := ce.ResourceManagerURL + "/tenants?api-version=2019-11-01"
return tenantURL
}
// GetTokenScope returns a token scope that fits Docker CLI Azure management API usage
func (ce *CloudEnvironment) GetTokenScope() string {
scope := "offline_access " + ce.ResourceManagerURL + "/.default"
return scope
}
// GetAuthorizeRequestFormat returns a string format that can be used to construct authorization code request in an OAuth2 flow.
// The URL uses login endpoint appropriate for given cloud environment.
func (ce *CloudEnvironment) GetAuthorizeRequestFormat() string {
authorizeFormat := ce.Authentication.LoginEndpoint + "/organizations/oauth2/v2.0/authorize?response_type=code&client_id=%s&redirect_uri=%s&state=%s&prompt=select_account&response_mode=query&scope=%s"
return authorizeFormat
}
// GetTokenRequestFormat returns a string format that can be used to construct a security token request against Azure Active Directory
func (ce *CloudEnvironment) GetTokenRequestFormat() string {
tokenEndpoint := ce.Authentication.LoginEndpoint + "/%s/oauth2/v2.0/token"
return tokenEndpoint
}
func (ce *CloudEnvironment) normalizeURLs() {
ce.ResourceManagerURL = removeTrailingSlash(ce.ResourceManagerURL)
ce.Authentication.LoginEndpoint = removeTrailingSlash(ce.Authentication.LoginEndpoint)
for i, s := range ce.Authentication.Audiences {
ce.Authentication.Audiences[i] = removeTrailingSlash(s)
}
}
func removeTrailingSlash(s string) string {
return strings.TrimSuffix(s, "/")
}

View File

@ -0,0 +1,187 @@
/*
Copyright 2020 Docker Compose CLI authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package login
import (
"testing"
"gotest.tools/v3/assert"
)
func TestNormalizeCloudEnvironmentURLs(t *testing.T) {
ce := CloudEnvironment{
Name: "SecretCloud",
Authentication: CloudEnvironmentAuthentication{
LoginEndpoint: "https://login.here.com/",
Audiences: []string{
"https://audience1",
"https://audience2/",
},
Tenant: "common",
},
ResourceManagerURL: "invalid URL",
}
ce.normalizeURLs()
assert.Equal(t, ce.Authentication.LoginEndpoint, "https://login.here.com")
assert.Equal(t, ce.Authentication.Audiences[0], "https://audience1")
assert.Equal(t, ce.Authentication.Audiences[1], "https://audience2")
}
func TestApplyInvalidCloudMetadataJSON(t *testing.T) {
ce := newCloudEnvironmentService()
bb := []byte(`This isn't really valid JSON`)
err := ce.applyCloudMetadata(bb)
assert.Assert(t, err != nil, "Cloud metadata was invalid, so the application should have failed")
ensureWellKnownCloudMetadata(t, ce)
}
func TestApplyInvalidCloudMetatada(t *testing.T) {
ce := newCloudEnvironmentService()
// No name (moniker) for the cloud
bb := []byte(`
[{
"authentication": {
"loginEndpoint": "https://login.docker.com/",
"audiences": [
"https://management.docker.com/",
"https://management.cli.docker.com/"
],
"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
},
"suffixes": {
"acrLoginServer": "azurecr.docker.io"
},
"resourceManager": "https://management.docker.com/"
}]`)
err := ce.applyCloudMetadata(bb)
assert.ErrorContains(t, err, "no name")
ensureWellKnownCloudMetadata(t, ce)
// Invalid resource manager URL
bb = []byte(`
[{
"authentication": {
"loginEndpoint": "https://login.docker.com/",
"audiences": [
"https://management.docker.com/",
"https://management.cli.docker.com/"
],
"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
},
"name": "DockerAzureCloud",
"suffixes": {
"acrLoginServer": "azurecr.docker.io"
},
"resourceManager": "123"
}]`)
err = ce.applyCloudMetadata(bb)
assert.ErrorContains(t, err, "invalid resource manager URL")
ensureWellKnownCloudMetadata(t, ce)
// Invalid login endpoint
bb = []byte(`
[{
"authentication": {
"loginEndpoint": "456",
"audiences": [
"https://management.docker.com/",
"https://management.cli.docker.com/"
],
"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
},
"name": "DockerAzureCloud",
"suffixes": {
"acrLoginServer": "azurecr.docker.io"
},
"resourceManager": "https://management.docker.com/"
}]`)
err = ce.applyCloudMetadata(bb)
assert.ErrorContains(t, err, "invalid login endpoint")
ensureWellKnownCloudMetadata(t, ce)
// No audiences
bb = []byte(`
[{
"authentication": {
"loginEndpoint": "https://login.docker.com/",
"audiences": [ ],
"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
},
"name": "DockerAzureCloud",
"suffixes": {
"acrLoginServer": "azurecr.docker.io"
},
"resourceManager": "https://management.docker.com/"
}]`)
err = ce.applyCloudMetadata(bb)
assert.ErrorContains(t, err, "no authentication audiences")
ensureWellKnownCloudMetadata(t, ce)
}
func TestApplyCloudMetadata(t *testing.T) {
ce := newCloudEnvironmentService()
bb := []byte(`
[{
"authentication": {
"loginEndpoint": "https://login.docker.com/",
"audiences": [
"https://management.docker.com/",
"https://management.cli.docker.com/"
],
"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
},
"name": "DockerAzureCloud",
"suffixes": {
"acrLoginServer": "azurecr.docker.io"
},
"resourceManager": "https://management.docker.com/"
}]`)
err := ce.applyCloudMetadata(bb)
assert.NilError(t, err)
env, err := ce.Get("DockerAzureCloud")
assert.NilError(t, err)
assert.Equal(t, env.Authentication.LoginEndpoint, "https://login.docker.com")
ensureWellKnownCloudMetadata(t, ce)
}
func TestDefaultCloudMetadataPresent(t *testing.T) {
ensureWellKnownCloudMetadata(t, CloudEnvironments)
}
func ensureWellKnownCloudMetadata(t *testing.T, ce CloudEnvironmentService) {
// Make sure well-known public cloud information is still available
_, err := ce.Get(AzurePublicCloudName)
assert.NilError(t, err)
_, err = ce.Get("AzureChinaCloud")
assert.NilError(t, err)
_, err = ce.Get("AzureUSGovernment")
assert.NilError(t, err)
}

View File

@ -39,17 +39,17 @@ var (
)
type apiHelper interface {
queryToken(data url.Values, tenantID string) (azureToken, error)
openAzureLoginPage(redirectURL string) error
queryToken(ce CloudEnvironment, data url.Values, tenantID string) (azureToken, error)
openAzureLoginPage(redirectURL string, ce CloudEnvironment) error
queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error)
getDeviceCodeFlowToken() (adal.Token, error)
getDeviceCodeFlowToken(ce CloudEnvironment) (adal.Token, error)
}
type azureAPIHelper struct{}
func (helper azureAPIHelper) getDeviceCodeFlowToken() (adal.Token, error) {
func (helper azureAPIHelper) getDeviceCodeFlowToken(ce CloudEnvironment) (adal.Token, error) {
deviceconfig := auth.NewDeviceFlowConfig(clientID, "common")
deviceconfig.Resource = azureManagementURL
deviceconfig.Resource = ce.ResourceManagerURL
spToken, err := deviceconfig.ServicePrincipalToken()
if err != nil {
return adal.Token{}, err
@ -57,9 +57,9 @@ func (helper azureAPIHelper) getDeviceCodeFlowToken() (adal.Token, error) {
return spToken.Token(), err
}
func (helper azureAPIHelper) openAzureLoginPage(redirectURL string) error {
func (helper azureAPIHelper) openAzureLoginPage(redirectURL string, ce CloudEnvironment) error {
state := randomString("", 10)
authURL := fmt.Sprintf(authorizeFormat, clientID, redirectURL, state, scopes)
authURL := fmt.Sprintf(ce.GetAuthorizeRequestFormat(), clientID, redirectURL, state, ce.GetTokenScope())
return openbrowser(authURL)
}
@ -81,8 +81,8 @@ func (helper azureAPIHelper) queryAPIWithHeader(ctx context.Context, authorizati
return bits, res.StatusCode, nil
}
func (helper azureAPIHelper) queryToken(data url.Values, tenantID string) (azureToken, error) {
res, err := http.Post(fmt.Sprintf(tokenEndpoint, tenantID), "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
func (helper azureAPIHelper) queryToken(ce CloudEnvironment, data url.Values, tenantID string) (azureToken, error) {
res, err := http.Post(fmt.Sprintf(ce.GetTokenRequestFormat(), tenantID), "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
if err != nil {
return azureToken{}, err
}

View File

@ -23,13 +23,10 @@ import (
"net/http"
"net/url"
"os"
"strconv"
"time"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure/auth"
"github.com/Azure/go-autorest/autorest/date"
"github.com/pkg/errors"
"golang.org/x/oauth2"
@ -38,18 +35,6 @@ import (
//go login process, derived from code sample provided by MS at https://github.com/devigned/go-az-cli-stuff
const (
// AcrRegistrySuffix suffix for ACR registry images
AcrRegistrySuffix = ".azurecr.io"
activeDirectoryURL = "https://login.microsoftonline.com"
azureManagementURL = "https://management.core.windows.net/"
azureResouceManagementURL = "https://management.azure.com/"
authorizeFormat = activeDirectoryURL + "/organizations/oauth2/v2.0/authorize?response_type=code&client_id=%s&redirect_uri=%s&state=%s&prompt=select_account&response_mode=query&scope=%s"
tokenEndpoint = activeDirectoryURL + "/%s/oauth2/v2.0/token"
getTenantURL = azureResouceManagementURL + "tenants?api-version=2019-11-01"
// scopes for a multi-tenant app works for openid, email, other common scopes, but fails when trying to add a token
// v1 scope like "https://management.azure.com/.default" for ARM access
scopes = "offline_access " + azureResouceManagementURL + ".default"
clientID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46" // Azure CLI client id
)
@ -73,39 +58,41 @@ type (
)
// AzureLoginService Service to log into azure and get authentifier for azure APIs
type AzureLoginService struct {
tokenStore tokenStore
apiHelper apiHelper
}
// AzureLoginServiceAPI interface for Azure login service
type AzureLoginServiceAPI interface {
LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error
Login(ctx context.Context, requestedTenantID string) error
type AzureLoginService interface {
Login(ctx context.Context, requestedTenantID string, cloudEnvironment string) error
LoginServicePrincipal(clientID string, clientSecret string, tenantID string, cloudEnvironment string) error
Logout(ctx context.Context) error
GetCloudEnvironment() (CloudEnvironment, error)
GetValidToken() (oauth2.Token, string, error)
}
type azureLoginService struct {
tokenStore tokenStore
apiHelper apiHelper
cloudEnvironmentSvc CloudEnvironmentService
}
const tokenStoreFilename = "dockerAccessToken.json"
// NewAzureLoginService creates a NewAzureLoginService
func NewAzureLoginService() (*AzureLoginService, error) {
return newAzureLoginServiceFromPath(GetTokenStorePath(), azureAPIHelper{})
func NewAzureLoginService() (AzureLoginService, error) {
return newAzureLoginServiceFromPath(GetTokenStorePath(), azureAPIHelper{}, CloudEnvironments)
}
func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper) (*AzureLoginService, error) {
func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper, ces CloudEnvironmentService) (*azureLoginService, error) {
store, err := newTokenStore(tokenStorePath)
if err != nil {
return nil, err
}
return &AzureLoginService{
tokenStore: store,
apiHelper: helper,
return &azureLoginService{
tokenStore: store,
apiHelper: helper,
cloudEnvironmentSvc: ces,
}, nil
}
// LoginServicePrincipal login with clientId / clientSecret from a service principal.
// The resulting token does not include a refresh token
func (login *AzureLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error {
func (login *azureLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string, cloudEnvironment string) error {
// Tried with auth2.NewUsernamePasswordConfig() but could not make this work with username / password, setting this for CI with clientID / clientSecret
creds := auth.NewClientCredentialsConfig(clientID, clientSecret, tenantID)
@ -121,7 +108,7 @@ func (login *AzureLoginService) LoginServicePrincipal(clientID string, clientSec
if err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "could not read service principal token expiry: %s", err)
}
loginInfo := TokenInfo{TenantID: tenantID, Token: token}
loginInfo := TokenInfo{TenantID: tenantID, Token: token, CloudEnvironment: cloudEnvironment}
if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err)
@ -130,7 +117,7 @@ func (login *AzureLoginService) LoginServicePrincipal(clientID string, clientSec
}
// Logout remove azure token data
func (login *AzureLoginService) Logout(ctx context.Context) error {
func (login *azureLoginService) Logout(ctx context.Context) error {
err := login.tokenStore.removeData()
if os.IsNotExist(err) {
return errors.New("No Azure login data to be removed")
@ -138,8 +125,14 @@ func (login *AzureLoginService) Logout(ctx context.Context) error {
return err
}
func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, accessToken string, refreshToken string, requestedTenantID string) error {
bits, statusCode, err := login.apiHelper.queryAPIWithHeader(ctx, getTenantURL, fmt.Sprintf("Bearer %s", accessToken))
func (login *azureLoginService) getTenantAndValidateLogin(
ctx context.Context,
accessToken string,
refreshToken string,
requestedTenantID string,
ce CloudEnvironment,
) error {
bits, statusCode, err := login.apiHelper.queryAPIWithHeader(ctx, ce.GetTenantQueryURL(), fmt.Sprintf("Bearer %s", accessToken))
if err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err)
}
@ -155,11 +148,11 @@ func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, a
if err != nil {
return errors.Wrap(errdefs.ErrLoginFailed, err.Error())
}
tToken, err := login.refreshToken(refreshToken, tenantID)
tToken, err := login.refreshToken(refreshToken, tenantID, ce)
if err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "unable to refresh token: %s", err)
}
loginInfo := TokenInfo{TenantID: tenantID, Token: tToken}
loginInfo := TokenInfo{TenantID: tenantID, Token: tToken, CloudEnvironment: ce.Name}
if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err)
@ -168,7 +161,12 @@ func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, a
}
// Login performs an Azure login through a web browser
func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID string) error {
func (login *azureLoginService) Login(ctx context.Context, requestedTenantID string, cloudEnvironment string) error {
ce, err := login.cloudEnvironmentSvc.Get(cloudEnvironment)
if err != nil {
return err
}
queryCh := make(chan localResponse, 1)
s, err := NewLocalServer(queryCh)
if err != nil {
@ -183,8 +181,8 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
}
deviceCodeFlowCh := make(chan deviceCodeFlowResponse, 1)
if err = login.apiHelper.openAzureLoginPage(redirectURL); err != nil {
login.startDeviceCodeFlow(deviceCodeFlowCh)
if err = login.apiHelper.openAzureLoginPage(redirectURL, ce); err != nil {
login.startDeviceCodeFlow(deviceCodeFlowCh, ce)
}
select {
@ -195,7 +193,7 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err)
}
token := dcft.token
return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID, ce)
case q := <-queryCh:
if q.err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "unhandled local login server error: %s", err)
@ -208,14 +206,14 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
"grant_type": []string{"authorization_code"},
"client_id": []string{clientID},
"code": code,
"scope": []string{scopes},
"scope": []string{ce.GetTokenScope()},
"redirect_uri": []string{redirectURL},
}
token, err := login.apiHelper.queryToken(data, "organizations")
token, err := login.apiHelper.queryToken(ce, data, "organizations")
if err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err)
}
return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID, ce)
}
}
@ -224,10 +222,10 @@ type deviceCodeFlowResponse struct {
err error
}
func (login *AzureLoginService) startDeviceCodeFlow(deviceCodeFlowCh chan deviceCodeFlowResponse) {
func (login *azureLoginService) startDeviceCodeFlow(deviceCodeFlowCh chan deviceCodeFlowResponse, ce CloudEnvironment) {
fmt.Println("Could not automatically open a browser, falling back to Azure device code flow authentication")
go func() {
token, err := login.apiHelper.getDeviceCodeFlowToken()
token, err := login.apiHelper.getDeviceCodeFlowToken(ce)
if err != nil {
deviceCodeFlowCh <- deviceCodeFlowResponse{err: err}
}
@ -276,72 +274,58 @@ func spToOAuthToken(token adal.Token) (oauth2.Token, error) {
return oauthToken, nil
}
// NewAuthorizerFromLogin creates an authorizer based on login access token
func NewAuthorizerFromLogin() (autorest.Authorizer, error) {
return newAuthorizerFromLoginStorePath(GetTokenStorePath())
}
func newAuthorizerFromLoginStorePath(storeTokenPath string) (autorest.Authorizer, error) {
login, err := newAzureLoginServiceFromPath(storeTokenPath, azureAPIHelper{})
if err != nil {
return nil, err
}
oauthToken, err := login.GetValidToken()
if err != nil {
return nil, errors.Wrap(err, "not logged in to azure, you need to run \"docker login azure\" first")
}
token := adal.Token{
AccessToken: oauthToken.AccessToken,
Type: oauthToken.TokenType,
ExpiresIn: json.Number(strconv.Itoa(int(time.Until(oauthToken.Expiry).Seconds()))),
ExpiresOn: json.Number(strconv.Itoa(int(oauthToken.Expiry.Sub(date.UnixEpoch()).Seconds()))),
RefreshToken: "",
Resource: "",
}
return autorest.NewBearerAuthorizer(&token), nil
}
// GetTenantID returns tenantID for current login
func (login AzureLoginService) GetTenantID() (string, error) {
// GetValidToken returns an access token and associated tenant ID.
// Will refresh the token as necessary.
func (login *azureLoginService) GetValidToken() (oauth2.Token, string, error) {
loginInfo, err := login.tokenStore.readToken()
if err != nil {
return "", err
}
return loginInfo.TenantID, err
}
// GetValidToken returns an access token. Refresh token if needed
func (login *AzureLoginService) GetValidToken() (oauth2.Token, error) {
loginInfo, err := login.tokenStore.readToken()
if err != nil {
return oauth2.Token{}, err
return oauth2.Token{}, "", err
}
token := loginInfo.Token
if token.Valid() {
return token, nil
}
tenantID := loginInfo.TenantID
token, err = login.refreshToken(token.RefreshToken, tenantID)
if err != nil {
return oauth2.Token{}, errors.Wrap(err, "access token request failed. Maybe you need to login to azure again.")
if token.Valid() {
return token, tenantID, nil
}
err = login.tokenStore.writeLoginInfo(TokenInfo{TenantID: tenantID, Token: token})
ce, err := login.cloudEnvironmentSvc.Get(loginInfo.CloudEnvironment)
if err != nil {
return oauth2.Token{}, err
return oauth2.Token{}, "", errors.Wrap(err, "access token request failed--cloud environment could not be determined.")
}
return token, nil
token, err = login.refreshToken(token.RefreshToken, tenantID, ce)
if err != nil {
return oauth2.Token{}, "", errors.Wrap(err, "access token request failed. Maybe you need to login to Azure again.")
}
err = login.tokenStore.writeLoginInfo(TokenInfo{TenantID: tenantID, Token: token, CloudEnvironment: ce.Name})
if err != nil {
return oauth2.Token{}, "", err
}
return token, tenantID, nil
}
func (login *AzureLoginService) refreshToken(currentRefreshToken string, tenantID string) (oauth2.Token, error) {
// GeCloudEnvironment returns the cloud environment associated with the current authentication token (if we have one)
func (login *azureLoginService) GetCloudEnvironment() (CloudEnvironment, error) {
tokenInfo, err := login.tokenStore.readToken()
if err != nil {
return CloudEnvironment{}, err
}
cloudEnvironment, err := login.cloudEnvironmentSvc.Get(tokenInfo.CloudEnvironment)
if err != nil {
return CloudEnvironment{}, err
}
return cloudEnvironment, nil
}
func (login *azureLoginService) refreshToken(currentRefreshToken string, tenantID string, ce CloudEnvironment) (oauth2.Token, error) {
data := url.Values{
"grant_type": []string{"refresh_token"},
"client_id": []string{clientID},
"scope": []string{scopes},
"scope": []string{ce.GetTokenScope()},
"refresh_token": []string{currentRefreshToken},
}
token, err := login.apiHelper.queryToken(data, tenantID)
token, err := login.apiHelper.queryToken(ce, data, tenantID)
if err != nil {
return oauth2.Token{}, err
}

View File

@ -21,10 +21,12 @@ import (
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"reflect"
"sync/atomic"
"testing"
"time"
@ -36,7 +38,7 @@ import (
"golang.org/x/oauth2"
)
func testLoginService(t *testing.T, m *MockAzureHelper) (*AzureLoginService, error) {
func testLoginService(t *testing.T, apiHelperMock *MockAzureHelper, cloudEnvironmentSvc CloudEnvironmentService) (*azureLoginService, error) {
dir, err := ioutil.TempDir("", "test_store")
if err != nil {
return nil, err
@ -44,20 +46,45 @@ func testLoginService(t *testing.T, m *MockAzureHelper) (*AzureLoginService, err
t.Cleanup(func() {
_ = os.RemoveAll(dir)
})
return newAzureLoginServiceFromPath(filepath.Join(dir, tokenStoreFilename), m)
ces := CloudEnvironments
if cloudEnvironmentSvc != nil {
ces = cloudEnvironmentSvc
}
return newAzureLoginServiceFromPath(filepath.Join(dir, tokenStoreFilename), apiHelperMock, ces)
}
func TestRefreshInValidToken(t *testing.T) {
data := refreshTokenData("refreshToken")
m := &MockAzureHelper{}
m.On("queryToken", data, "123456").Return(azureToken{
data := url.Values{
"grant_type": []string{"refresh_token"},
"client_id": []string{clientID},
"scope": []string{"offline_access https://management.docker.com/.default"},
"refresh_token": []string{"refreshToken"},
}
helperMock := &MockAzureHelper{}
helperMock.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "123456").Return(azureToken{
RefreshToken: "newRefreshToken",
AccessToken: "newAccessToken",
ExpiresIn: 3600,
Foci: "1",
}, nil)
azureLogin, err := testLoginService(t, m)
cloudEnvironmentSvcMock := &MockCloudEnvironmentService{}
cloudEnvironmentSvcMock.On("Get", "AzureDockerCloud").Return(CloudEnvironment{
Name: "AzureDockerCloud",
Authentication: CloudEnvironmentAuthentication{
LoginEndpoint: "https://login.docker.com",
Audiences: []string{
"https://management.docker.com",
"https://management-ext.docker.com",
},
Tenant: "common",
},
ResourceManagerURL: "https://management.docker.com",
Suffixes: map[string]string{},
}, nil)
azureLogin, err := testLoginService(t, helperMock, cloudEnvironmentSvcMock)
assert.NilError(t, err)
err = azureLogin.tokenStore.writeLoginInfo(TokenInfo{
TenantID: "123456",
@ -67,33 +94,29 @@ func TestRefreshInValidToken(t *testing.T) {
Expiry: time.Now().Add(-1 * time.Hour),
TokenType: "Bearer",
},
CloudEnvironment: "AzureDockerCloud",
})
assert.NilError(t, err)
token, _ := azureLogin.GetValidToken()
token, tenantID, err := azureLogin.GetValidToken()
assert.NilError(t, err)
assert.Equal(t, tenantID, "123456")
assert.Equal(t, token.AccessToken, "newAccessToken")
assert.Assert(t, time.Now().Add(3500*time.Second).Before(token.Expiry))
storedToken, _ := azureLogin.tokenStore.readToken()
storedToken, err := azureLogin.tokenStore.readToken()
assert.NilError(t, err)
assert.Equal(t, storedToken.Token.AccessToken, "newAccessToken")
assert.Equal(t, storedToken.Token.RefreshToken, "newRefreshToken")
assert.Assert(t, time.Now().Add(3500*time.Second).Before(storedToken.Token.Expiry))
}
func TestClearErrorMessageIfNotAlreadyLoggedIn(t *testing.T) {
dir, err := ioutil.TempDir("", "test_store")
assert.NilError(t, err)
t.Cleanup(func() {
_ = os.RemoveAll(dir)
})
_, err = newAuthorizerFromLoginStorePath(filepath.Join(dir, tokenStoreFilename))
assert.ErrorContains(t, err, "not logged in to azure, you need to run \"docker login azure\" first")
assert.Equal(t, storedToken.CloudEnvironment, "AzureDockerCloud")
}
func TestDoesNotRefreshValidToken(t *testing.T) {
expiryDate := time.Now().Add(1 * time.Hour)
azureLogin, err := testLoginService(t, nil)
azureLogin, err := testLoginService(t, nil, nil)
assert.NilError(t, err)
err = azureLogin.tokenStore.writeLoginInfo(TokenInfo{
TenantID: "123456",
@ -103,25 +126,55 @@ func TestDoesNotRefreshValidToken(t *testing.T) {
Expiry: expiryDate,
TokenType: "Bearer",
},
CloudEnvironment: AzurePublicCloudName,
})
assert.NilError(t, err)
token, _ := azureLogin.GetValidToken()
token, tenantID, err := azureLogin.GetValidToken()
assert.NilError(t, err)
assert.Equal(t, token.AccessToken, "accessToken")
assert.Equal(t, tenantID, "123456")
}
func TestTokenStoreAssumesAzurePublicCloud(t *testing.T) {
expiryDate := time.Now().Add(1 * time.Hour)
azureLogin, err := testLoginService(t, nil, nil)
assert.NilError(t, err)
err = azureLogin.tokenStore.writeLoginInfo(TokenInfo{
TenantID: "123456",
Token: oauth2.Token{
AccessToken: "accessToken",
RefreshToken: "refreshToken",
Expiry: expiryDate,
TokenType: "Bearer",
},
// Simulates upgrade from older version of Docker CLI that did not have cloud environment concept
CloudEnvironment: "",
})
assert.NilError(t, err)
token, tenantID, err := azureLogin.GetValidToken()
assert.NilError(t, err)
assert.Equal(t, tenantID, "123456")
assert.Equal(t, token.AccessToken, "accessToken")
ce, err := azureLogin.GetCloudEnvironment()
assert.NilError(t, err)
assert.Equal(t, ce.Name, AzurePublicCloudName)
}
func TestInvalidLogin(t *testing.T) {
m := &MockAzureHelper{}
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
redirectURL := args.Get(0).(string)
err := queryKeyValue(redirectURL, "error", "access denied: login failed")
assert.NilError(t, err)
}).Return(nil)
azureLogin, err := testLoginService(t, m)
azureLogin, err := testLoginService(t, m, nil)
assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "")
err = azureLogin.Login(context.TODO(), "", AzurePublicCloudName)
assert.Error(t, err, "no login code: login failed")
}
@ -129,19 +182,22 @@ func TestValidLogin(t *testing.T) {
var redirectURL string
ctx := context.TODO()
m := &MockAzureHelper{}
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
ce, err := CloudEnvironments.Get(AzurePublicCloudName)
assert.NilError(t, err)
m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
redirectURL = args.Get(0).(string)
err := queryKeyValue(redirectURL, "code", "123456879")
assert.NilError(t, err)
}).Return(nil)
m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
return reflect.DeepEqual(data, url.Values{
"grant_type": []string{"authorization_code"},
"client_id": []string{clientID},
"code": []string{"123456879"},
"scope": []string{scopes},
"scope": []string{ce.GetTokenScope()},
"redirect_uri": []string{redirectURL},
})
}), "organizations").Return(azureToken{
@ -153,18 +209,18 @@ func TestValidLogin(t *testing.T) {
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken")
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken", ce)
m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
RefreshToken: "newRefreshToken",
AccessToken: "newAccessToken",
ExpiresIn: 3600,
Foci: "1",
}, nil)
azureLogin, err := testLoginService(t, m)
azureLogin, err := testLoginService(t, m, nil)
assert.NilError(t, err)
err = azureLogin.Login(ctx, "")
err = azureLogin.Login(ctx, "", AzurePublicCloudName)
assert.NilError(t, err)
loginToken, err := azureLogin.tokenStore.readToken()
@ -174,24 +230,28 @@ func TestValidLogin(t *testing.T) {
assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry))
assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038")
assert.Equal(t, loginToken.Token.Type(), "Bearer")
assert.Equal(t, loginToken.CloudEnvironment, "AzureCloud")
}
func TestValidLoginRequestedTenant(t *testing.T) {
var redirectURL string
m := &MockAzureHelper{}
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
ce, err := CloudEnvironments.Get(AzurePublicCloudName)
assert.NilError(t, err)
m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
redirectURL = args.Get(0).(string)
err := queryKeyValue(redirectURL, "code", "123456879")
assert.NilError(t, err)
}).Return(nil)
m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
return reflect.DeepEqual(data, url.Values{
"grant_type": []string{"authorization_code"},
"client_id": []string{clientID},
"code": []string{"123456879"},
"scope": []string{scopes},
"scope": []string{ce.GetTokenScope()},
"redirect_uri": []string{redirectURL},
})
}), "organizations").Return(azureToken{
@ -205,18 +265,18 @@ func TestValidLoginRequestedTenant(t *testing.T) {
{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
ctx := context.TODO()
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken")
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken", ce)
m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
RefreshToken: "newRefreshToken",
AccessToken: "newAccessToken",
ExpiresIn: 3600,
Foci: "1",
}, nil)
azureLogin, err := testLoginService(t, m)
azureLogin, err := testLoginService(t, m, nil)
assert.NilError(t, err)
err = azureLogin.Login(ctx, "12345a7c-c56d-43e8-9549-dd230ce8a038")
err = azureLogin.Login(ctx, "12345a7c-c56d-43e8-9549-dd230ce8a038", AzurePublicCloudName)
assert.NilError(t, err)
loginToken, err := azureLogin.tokenStore.readToken()
@ -226,24 +286,28 @@ func TestValidLoginRequestedTenant(t *testing.T) {
assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry))
assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038")
assert.Equal(t, loginToken.Token.Type(), "Bearer")
assert.Equal(t, loginToken.CloudEnvironment, "AzureCloud")
}
func TestLoginNoTenant(t *testing.T) {
var redirectURL string
m := &MockAzureHelper{}
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
ce, err := CloudEnvironments.Get(AzurePublicCloudName)
assert.NilError(t, err)
m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
redirectURL = args.Get(0).(string)
err := queryKeyValue(redirectURL, "code", "123456879")
assert.NilError(t, err)
}).Return(nil)
m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
return reflect.DeepEqual(data, url.Values{
"grant_type": []string{"authorization_code"},
"client_id": []string{clientID},
"code": []string{"123456879"},
"scope": []string{scopes},
"scope": []string{ce.GetTokenScope()},
"redirect_uri": []string{redirectURL},
})
}), "organizations").Return(azureToken{
@ -255,31 +319,34 @@ func TestLoginNoTenant(t *testing.T) {
ctx := context.TODO()
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
azureLogin, err := testLoginService(t, m)
azureLogin, err := testLoginService(t, m, nil)
assert.NilError(t, err)
err = azureLogin.Login(ctx, "00000000-c56d-43e8-9549-dd230ce8a038")
err = azureLogin.Login(ctx, "00000000-c56d-43e8-9549-dd230ce8a038", AzurePublicCloudName)
assert.Error(t, err, "could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed")
}
func TestLoginRequestedTenantNotFound(t *testing.T) {
var redirectURL string
m := &MockAzureHelper{}
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
ce, err := CloudEnvironments.Get(AzurePublicCloudName)
assert.NilError(t, err)
m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
redirectURL = args.Get(0).(string)
err := queryKeyValue(redirectURL, "code", "123456879")
assert.NilError(t, err)
}).Return(nil)
m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
return reflect.DeepEqual(data, url.Values{
"grant_type": []string{"authorization_code"},
"client_id": []string{clientID},
"code": []string{"123456879"},
"scope": []string{scopes},
"scope": []string{ce.GetTokenScope()},
"redirect_uri": []string{redirectURL},
})
}), "organizations").Return(azureToken{
@ -291,31 +358,34 @@ func TestLoginRequestedTenantNotFound(t *testing.T) {
ctx := context.TODO()
authBody := `{"value":[]}`
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
azureLogin, err := testLoginService(t, m)
azureLogin, err := testLoginService(t, m, nil)
assert.NilError(t, err)
err = azureLogin.Login(ctx, "")
err = azureLogin.Login(ctx, "", AzurePublicCloudName)
assert.Error(t, err, "could not find azure tenant: login failed")
}
func TestLoginAuthorizationFailed(t *testing.T) {
var redirectURL string
m := &MockAzureHelper{}
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
ce, err := CloudEnvironments.Get(AzurePublicCloudName)
assert.NilError(t, err)
m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
redirectURL = args.Get(0).(string)
err := queryKeyValue(redirectURL, "code", "123456879")
assert.NilError(t, err)
}).Return(nil)
m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
return reflect.DeepEqual(data, url.Values{
"grant_type": []string{"authorization_code"},
"client_id": []string{clientID},
"code": []string{"123456879"},
"scope": []string{scopes},
"scope": []string{ce.GetTokenScope()},
"redirect_uri": []string{redirectURL},
})
}), "organizations").Return(azureToken{
@ -328,35 +398,38 @@ func TestLoginAuthorizationFailed(t *testing.T) {
authBody := `[access denied]`
ctx := context.TODO()
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
azureLogin, err := testLoginService(t, m)
azureLogin, err := testLoginService(t, m, nil)
assert.NilError(t, err)
err = azureLogin.Login(ctx, "")
err = azureLogin.Login(ctx, "", AzurePublicCloudName)
assert.Error(t, err, "unable to login status code 400: [access denied]: login failed")
}
func TestValidThroughDeviceCodeFlow(t *testing.T) {
m := &MockAzureHelper{}
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Return(errors.New("Could not open browser"))
m.On("getDeviceCodeFlowToken").Return(adal.Token{AccessToken: "firstAccessToken", RefreshToken: "firstRefreshToken"}, nil)
ce, err := CloudEnvironments.Get(AzurePublicCloudName)
assert.NilError(t, err)
m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Return(errors.New("Could not open browser"))
m.On("getDeviceCodeFlowToken", mock.AnythingOfType("CloudEnvironment")).Return(adal.Token{AccessToken: "firstAccessToken", RefreshToken: "firstRefreshToken"}, nil)
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
ctx := context.TODO()
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken")
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken", ce)
m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
RefreshToken: "newRefreshToken",
AccessToken: "newAccessToken",
ExpiresIn: 3600,
Foci: "1",
}, nil)
azureLogin, err := testLoginService(t, m)
azureLogin, err := testLoginService(t, m, nil)
assert.NilError(t, err)
err = azureLogin.Login(ctx, "")
err = azureLogin.Login(ctx, "", AzurePublicCloudName)
assert.NilError(t, err)
loginToken, err := azureLogin.tokenStore.readToken()
@ -366,13 +439,110 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) {
assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry))
assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038")
assert.Equal(t, loginToken.Token.Type(), "Bearer")
assert.Equal(t, loginToken.CloudEnvironment, "AzureCloud")
}
func refreshTokenData(refreshToken string) url.Values {
func TestNonstandardCloudEnvironment(t *testing.T) {
dockerCloudMetadata := []byte(`
[{
"authentication": {
"loginEndpoint": "https://login.docker.com/",
"audiences": [
"https://management.docker.com/",
"https://management.cli.docker.com/"
],
"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
},
"name": "AzureDockerCloud",
"suffixes": {
"acrLoginServer": "azurecr.docker.io"
},
"resourceManager": "https://management.docker.com/"
}]`)
var metadataReqCount int32 = 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write(dockerCloudMetadata)
assert.NilError(t, err)
atomic.AddInt32(&metadataReqCount, 1)
}))
defer srv.Close()
cloudMetadataURL, cloudMetadataURLSet := os.LookupEnv(CloudMetadataURLVar)
if cloudMetadataURLSet {
defer func() {
err := os.Setenv(CloudMetadataURLVar, cloudMetadataURL)
assert.NilError(t, err)
}()
}
err := os.Setenv(CloudMetadataURLVar, srv.URL)
assert.NilError(t, err)
ctx := context.TODO()
ces := newCloudEnvironmentService()
ces.cloudMetadataURL = srv.URL
dockerCloudEnv, err := ces.Get("AzureDockerCloud")
assert.NilError(t, err)
helperMock := &MockAzureHelper{}
var redirectURL string
helperMock.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
redirectURL = args.Get(0).(string)
err := queryKeyValue(redirectURL, "code", "123456879")
assert.NilError(t, err)
}).Return(nil)
helperMock.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
return reflect.DeepEqual(data, url.Values{
"grant_type": []string{"authorization_code"},
"client_id": []string{clientID},
"code": []string{"123456879"},
"scope": []string{dockerCloudEnv.GetTokenScope()},
"redirect_uri": []string{redirectURL},
})
}), "organizations").Return(azureToken{
RefreshToken: "firstRefreshToken",
AccessToken: "firstAccessToken",
ExpiresIn: 3600,
Foci: "1",
}, nil)
authBody := `{"value":[{"id":"/tenants/F5773994-FE88-482E-9E33-6E799D250416","tenantId":"F5773994-FE88-482E-9E33-6E799D250416"}]}`
helperMock.On("queryAPIWithHeader", ctx, dockerCloudEnv.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken", dockerCloudEnv)
helperMock.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "F5773994-FE88-482E-9E33-6E799D250416").Return(azureToken{
RefreshToken: "newRefreshToken",
AccessToken: "newAccessToken",
ExpiresIn: 3600,
Foci: "1",
}, nil)
azureLogin, err := testLoginService(t, helperMock, ces)
assert.NilError(t, err)
err = azureLogin.Login(ctx, "", "AzureDockerCloud")
assert.NilError(t, err)
loginToken, err := azureLogin.tokenStore.readToken()
assert.NilError(t, err)
assert.Equal(t, loginToken.Token.AccessToken, "newAccessToken")
assert.Equal(t, loginToken.Token.RefreshToken, "newRefreshToken")
assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry))
assert.Equal(t, loginToken.TenantID, "F5773994-FE88-482E-9E33-6E799D250416")
assert.Equal(t, loginToken.Token.Type(), "Bearer")
assert.Equal(t, loginToken.CloudEnvironment, "AzureDockerCloud")
assert.Equal(t, metadataReqCount, int32(1))
}
// Don't warn about refreshToken parameter taking the same value for all invocations
// nolint:unparam
func refreshTokenData(refreshToken string, ce CloudEnvironment) url.Values {
return url.Values{
"grant_type": []string{"refresh_token"},
"client_id": []string{clientID},
"scope": []string{scopes},
"scope": []string{ce.GetTokenScope()},
"refresh_token": []string{refreshToken},
}
}
@ -394,13 +564,13 @@ type MockAzureHelper struct {
mock.Mock
}
func (s *MockAzureHelper) getDeviceCodeFlowToken() (adal.Token, error) {
args := s.Called()
func (s *MockAzureHelper) getDeviceCodeFlowToken(ce CloudEnvironment) (adal.Token, error) {
args := s.Called(ce)
return args.Get(0).(adal.Token), args.Error(1)
}
func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token azureToken, err error) {
args := s.Called(data, tenantID)
func (s *MockAzureHelper) queryToken(ce CloudEnvironment, data url.Values, tenantID string) (token azureToken, err error) {
args := s.Called(ce, data, tenantID)
return args.Get(0).(azureToken), args.Error(1)
}
@ -409,7 +579,16 @@ func (s *MockAzureHelper) queryAPIWithHeader(ctx context.Context, authorizationU
return args.Get(0).([]byte), args.Int(1), args.Error(2)
}
func (s *MockAzureHelper) openAzureLoginPage(redirectURL string) error {
args := s.Called(redirectURL)
func (s *MockAzureHelper) openAzureLoginPage(redirectURL string, ce CloudEnvironment) error {
args := s.Called(redirectURL, ce)
return args.Error(0)
}
type MockCloudEnvironmentService struct {
mock.Mock
}
func (s *MockCloudEnvironmentService) Get(name string) (CloudEnvironment, error) {
args := s.Called(name)
return args.Get(0).(CloudEnvironment), args.Error(1)
}

View File

@ -34,8 +34,9 @@ type tokenStore struct {
// TokenInfo data stored in tokenStore
type TokenInfo struct {
Token oauth2.Token `json:"oauthToken"`
TenantID string `json:"tenantId"`
Token oauth2.Token `json:"oauthToken"`
TenantID string `json:"tenantId"`
CloudEnvironment string `json:"cloudEnvironment"`
}
func newTokenStore(path string) (tokenStore, error) {
@ -82,6 +83,9 @@ func (store tokenStore) readToken() (TokenInfo, error) {
if err := json.Unmarshal(bytes, &loginInfo); err != nil {
return TokenInfo{}, err
}
if loginInfo.CloudEnvironment == "" {
loginInfo.CloudEnvironment = AzurePublicCloudName
}
return loginInfo, nil
}

View File

@ -40,6 +40,7 @@ func AzureLoginCommand() *cobra.Command {
flags.StringVar(&opts.TenantID, "tenant-id", "", "Specify tenant ID to use")
flags.StringVar(&opts.ClientID, "client-id", "", "Client ID for Service principal login")
flags.StringVar(&opts.ClientSecret, "client-secret", "", "Client secret for Service principal login")
flags.StringVar(&opts.CloudName, "cloud-name", "", "Name of a registered Azure cloud")
return cmd
}