Add Azure sovereign cloud support

Signed-off-by: Karol Zadora-Przylecki <karolz@microsoft.com>
This commit is contained in:
Karol Zadora-Przylecki 2021-01-31 12:15:00 -07:00
parent 9063c138ba
commit cc649d958c
15 changed files with 973 additions and 236 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
bin/ bin/
dist/ dist/
/.vscode/

View File

@ -51,6 +51,7 @@ type LoginParams struct {
TenantID string TenantID string
ClientID string ClientID string
ClientSecret string ClientSecret string
CloudName string
} }
// Validate returns an error if options are not used properly // Validate returns an error if options are not used properly

View File

@ -23,7 +23,9 @@ import (
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
"github.com/docker/compose-cli/aci/login"
"github.com/docker/compose-cli/api/containers" "github.com/docker/compose-cli/api/containers"
"golang.org/x/oauth2"
) )
func TestGetContainerName(t *testing.T) { func TestGetContainerName(t *testing.T) {
@ -82,7 +84,7 @@ func TestLoginParamsValidate(t *testing.T) {
func TestLoginServicePrincipal(t *testing.T) { func TestLoginServicePrincipal(t *testing.T) {
loginService := mockLoginService{} loginService := mockLoginService{}
loginService.On("LoginServicePrincipal", "someID", "secret", "tenant").Return(nil) loginService.On("LoginServicePrincipal", "someID", "secret", "tenant", "someCloud").Return(nil)
loginBackend := aciCloudService{ loginBackend := aciCloudService{
loginService: &loginService, loginService: &loginService,
} }
@ -91,6 +93,7 @@ func TestLoginServicePrincipal(t *testing.T) {
ClientID: "someID", ClientID: "someID",
ClientSecret: "secret", ClientSecret: "secret",
TenantID: "tenant", TenantID: "tenant",
CloudName: "someCloud",
}) })
assert.NilError(t, err) assert.NilError(t, err)
@ -99,13 +102,14 @@ func TestLoginServicePrincipal(t *testing.T) {
func TestLoginWithTenant(t *testing.T) { func TestLoginWithTenant(t *testing.T) {
loginService := mockLoginService{} loginService := mockLoginService{}
ctx := context.Background() ctx := context.Background()
loginService.On("Login", ctx, "tenant").Return(nil) loginService.On("Login", ctx, "tenant", "someCloud").Return(nil)
loginBackend := aciCloudService{ loginBackend := aciCloudService{
loginService: &loginService, loginService: &loginService,
} }
err := loginBackend.Login(ctx, LoginParams{ err := loginBackend.Login(ctx, LoginParams{
TenantID: "tenant", TenantID: "tenant",
CloudName: "someCloud",
}) })
assert.NilError(t, err) assert.NilError(t, err)
@ -114,12 +118,14 @@ func TestLoginWithTenant(t *testing.T) {
func TestLoginWithoutTenant(t *testing.T) { func TestLoginWithoutTenant(t *testing.T) {
loginService := mockLoginService{} loginService := mockLoginService{}
ctx := context.Background() ctx := context.Background()
loginService.On("Login", ctx, "").Return(nil) loginService.On("Login", ctx, "", "someCloud").Return(nil)
loginBackend := aciCloudService{ loginBackend := aciCloudService{
loginService: &loginService, loginService: &loginService,
} }
err := loginBackend.Login(ctx, LoginParams{}) err := loginBackend.Login(ctx, LoginParams{
CloudName: "someCloud",
})
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -128,13 +134,13 @@ type mockLoginService struct {
mock.Mock mock.Mock
} }
func (s *mockLoginService) Login(ctx context.Context, requestedTenantID string) error { func (s *mockLoginService) Login(ctx context.Context, requestedTenantID string, cloudEnvironment string) error {
args := s.Called(ctx, requestedTenantID) args := s.Called(ctx, requestedTenantID, cloudEnvironment)
return args.Error(0) return args.Error(0)
} }
func (s *mockLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error { func (s *mockLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string, cloudEnvironment string) error {
args := s.Called(clientID, clientSecret, tenantID) args := s.Called(clientID, clientSecret, tenantID, cloudEnvironment)
return args.Error(0) return args.Error(0)
} }
@ -142,3 +148,18 @@ func (s *mockLoginService) Logout(ctx context.Context) error {
args := s.Called(ctx) args := s.Called(ctx)
return args.Error(0) return args.Error(0)
} }
func (s *mockLoginService) GetTenantID() (string, error) {
args := s.Called()
return args.String(0), args.Error(1)
}
func (s *mockLoginService) GetCloudEnvironment() (login.CloudEnvironment, error) {
args := s.Called()
return args.Get(0).(login.CloudEnvironment), args.Error(1)
}
func (s *mockLoginService) GetValidToken() (oauth2.Token, string, error) {
args := s.Called()
return args.Get(0).(oauth2.Token), args.String(1), args.Error(2)
}

View File

@ -25,7 +25,7 @@ import (
) )
type aciCloudService struct { type aciCloudService struct {
loginService login.AzureLoginServiceAPI loginService login.AzureLoginService
} }
func (cs *aciCloudService) Login(ctx context.Context, params interface{}) error { func (cs *aciCloudService) Login(ctx context.Context, params interface{}) error {
@ -33,10 +33,13 @@ func (cs *aciCloudService) Login(ctx context.Context, params interface{}) error
if !ok { if !ok {
return errors.New("could not read Azure LoginParams struct from generic parameter") return errors.New("could not read Azure LoginParams struct from generic parameter")
} }
if opts.ClientID != "" { if opts.CloudName == "" {
return cs.loginService.LoginServicePrincipal(opts.ClientID, opts.ClientSecret, opts.TenantID) opts.CloudName = login.AzurePublicCloudName
} }
return cs.loginService.Login(ctx, opts.TenantID) if opts.ClientID != "" {
return cs.loginService.LoginServicePrincipal(opts.ClientID, opts.ClientSecret, opts.TenantID, opts.CloudName)
}
return cs.loginService.Login(ctx, opts.TenantID, opts.CloudName)
} }
func (cs *aciCloudService) Logout(ctx context.Context) error { func (cs *aciCloudService) Logout(ctx context.Context) error {

View File

@ -47,7 +47,7 @@ const (
type registryHelper interface { type registryHelper interface {
getAllRegistryCredentials() (map[string]types.AuthConfig, error) getAllRegistryCredentials() (map[string]types.AuthConfig, error)
autoLoginAcr(registry string) error autoLoginAcr(registry string, loginService login.AzureLoginService) error
} }
type cliRegistryHelper struct { type cliRegistryHelper struct {
@ -65,9 +65,19 @@ func newCliRegistryConfLoader() cliRegistryHelper {
} }
func getRegistryCredentials(project compose.Project, helper registryHelper) ([]containerinstance.ImageRegistryCredential, error) { func getRegistryCredentials(project compose.Project, helper registryHelper) ([]containerinstance.ImageRegistryCredential, error) {
usedRegistries, acrRegistries := getUsedRegistries(project) loginService, err := login.NewAzureLoginService()
if err != nil {
return nil, err
}
var cloudEnvironment *login.CloudEnvironment = nil
if ce, err := loginService.GetCloudEnvironment(); err != nil {
cloudEnvironment = &ce
}
usedRegistries, acrRegistries := getUsedRegistries(project, cloudEnvironment)
for _, registry := range acrRegistries { for _, registry := range acrRegistries {
err := helper.autoLoginAcr(registry) err := helper.autoLoginAcr(registry, loginService)
if err != nil { if err != nil {
fmt.Printf("WARNING: %v\n", err) fmt.Printf("WARNING: %v\n", err)
fmt.Printf("Could not automatically login to %s from your Azure login. Assuming you already logged in to the ACR registry\n", registry) fmt.Printf("Could not automatically login to %s from your Azure login. Assuming you already logged in to the ACR registry\n", registry)
@ -116,9 +126,10 @@ func getRegistryCredentials(project compose.Project, helper registryHelper) ([]c
return registryCreds, nil return registryCreds, nil
} }
func getUsedRegistries(project compose.Project) (map[string]bool, []string) { func getUsedRegistries(project compose.Project, ce *login.CloudEnvironment) (map[string]bool, []string) {
usedRegistries := map[string]bool{} usedRegistries := map[string]bool{}
acrRegistries := []string{} acrRegistries := []string{}
for _, service := range project.Services { for _, service := range project.Services {
imageName := service.Image imageName := service.Image
tokens := strings.Split(imageName, "/") tokens := strings.Split(imageName, "/")
@ -127,24 +138,18 @@ func getUsedRegistries(project compose.Project) (map[string]bool, []string) {
registry = dockerHub registry = dockerHub
} else if !strings.Contains(registry, ".") { } else if !strings.Contains(registry, ".") {
registry = dockerHub registry = dockerHub
} else if strings.HasSuffix(registry, login.AcrRegistrySuffix) { } else if ce != nil {
acrRegistries = append(acrRegistries, registry) if suffix, present := ce.Suffixes[login.AcrSuffixKey]; present && strings.HasSuffix(registry, suffix) {
acrRegistries = append(acrRegistries, registry)
}
} }
usedRegistries[registry] = true usedRegistries[registry] = true
} }
return usedRegistries, acrRegistries return usedRegistries, acrRegistries
} }
func (c cliRegistryHelper) autoLoginAcr(registry string) error { func (c cliRegistryHelper) autoLoginAcr(registry string, loginService login.AzureLoginService) error {
loginService, err := login.NewAzureLoginService() token, tenantID, err := loginService.GetValidToken()
if err != nil {
return err
}
token, err := loginService.GetValidToken()
if err != nil {
return err
}
tenantID, err := loginService.GetTenantID()
if err != nil { if err != nil {
return err return err
} }

View File

@ -25,6 +25,7 @@ import (
"github.com/Azure/go-autorest/autorest/to" "github.com/Azure/go-autorest/autorest/to"
"github.com/compose-spec/compose-go/types" "github.com/compose-spec/compose-go/types"
cliconfigtypes "github.com/docker/cli/cli/config/types" cliconfigtypes "github.com/docker/cli/cli/config/types"
"github.com/docker/compose-cli/aci/login"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
is "gotest.tools/v3/assert/cmp" is "gotest.tools/v3/assert/cmp"
@ -255,7 +256,7 @@ func (s *MockRegistryHelper) getAllRegistryCredentials() (map[string]cliconfigty
return args.Get(0).(map[string]cliconfigtypes.AuthConfig), args.Error(1) return args.Get(0).(map[string]cliconfigtypes.AuthConfig), args.Error(1)
} }
func (s *MockRegistryHelper) autoLoginAcr(registry string) error { func (s *MockRegistryHelper) autoLoginAcr(registry string, loginService login.AzureLoginService) error {
args := s.Called(registry) args := s.Called(registry, loginService)
return args.Error(0) return args.Error(0)
} }

View File

@ -17,6 +17,8 @@
package login package login
import ( import (
"encoding/json"
"strconv"
"time" "time"
"github.com/Azure/azure-sdk-for-go/profiles/2019-03-01/resources/mgmt/resources" "github.com/Azure/azure-sdk-for-go/profiles/2019-03-01/resources/mgmt/resources"
@ -24,6 +26,8 @@ import (
"github.com/Azure/azure-sdk-for-go/services/containerinstance/mgmt/2019-12-01/containerinstance" "github.com/Azure/azure-sdk-for-go/services/containerinstance/mgmt/2019-12-01/containerinstance"
"github.com/Azure/azure-sdk-for-go/services/storage/mgmt/2019-06-01/storage" "github.com/Azure/azure-sdk-for-go/services/storage/mgmt/2019-06-01/storage"
"github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/date"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/docker/compose-cli/api/errdefs" "github.com/docker/compose-cli/api/errdefs"
@ -32,8 +36,12 @@ import (
// NewContainerGroupsClient get client toi manipulate containerGrouos // NewContainerGroupsClient get client toi manipulate containerGrouos
func NewContainerGroupsClient(subscriptionID string) (containerinstance.ContainerGroupsClient, error) { func NewContainerGroupsClient(subscriptionID string) (containerinstance.ContainerGroupsClient, error) {
containerGroupsClient := containerinstance.NewContainerGroupsClient(subscriptionID) authorizer, mgmtURL, err := getClientSetupData()
err := setupClient(&containerGroupsClient.Client) if err != nil {
return containerinstance.ContainerGroupsClient{}, err
}
containerGroupsClient := containerinstance.NewContainerGroupsClientWithBaseURI(mgmtURL, subscriptionID)
setupClient(&containerGroupsClient.Client, authorizer)
if err != nil { if err != nil {
return containerinstance.ContainerGroupsClient{}, err return containerinstance.ContainerGroupsClient{}, err
} }
@ -43,68 +51,100 @@ func NewContainerGroupsClient(subscriptionID string) (containerinstance.Containe
return containerGroupsClient, nil return containerGroupsClient, nil
} }
func setupClient(aciClient *autorest.Client) error { func setupClient(aciClient *autorest.Client, auth autorest.Authorizer) {
aciClient.UserAgent = internal.UserAgentName + "/" + internal.Version aciClient.UserAgent = internal.UserAgentName + "/" + internal.Version
auth, err := NewAuthorizerFromLogin()
if err != nil {
return err
}
aciClient.Authorizer = auth aciClient.Authorizer = auth
return nil
} }
// NewStorageAccountsClient get client to manipulate storage accounts // NewStorageAccountsClient get client to manipulate storage accounts
func NewStorageAccountsClient(subscriptionID string) (storage.AccountsClient, error) { func NewStorageAccountsClient(subscriptionID string) (storage.AccountsClient, error) {
containerGroupsClient := storage.NewAccountsClient(subscriptionID) authorizer, mgmtURL, err := getClientSetupData()
err := setupClient(&containerGroupsClient.Client)
if err != nil { if err != nil {
return storage.AccountsClient{}, err return storage.AccountsClient{}, err
} }
containerGroupsClient.PollingDelay = 5 * time.Second storageAccuntsClient := storage.NewAccountsClientWithBaseURI(mgmtURL, subscriptionID)
containerGroupsClient.RetryAttempts = 30 setupClient(&storageAccuntsClient.Client, authorizer)
containerGroupsClient.RetryDuration = 1 * time.Second storageAccuntsClient.PollingDelay = 5 * time.Second
return containerGroupsClient, nil storageAccuntsClient.RetryAttempts = 30
storageAccuntsClient.RetryDuration = 1 * time.Second
return storageAccuntsClient, nil
} }
// NewFileShareClient get client to manipulate file shares // NewFileShareClient get client to manipulate file shares
func NewFileShareClient(subscriptionID string) (storage.FileSharesClient, error) { func NewFileShareClient(subscriptionID string) (storage.FileSharesClient, error) {
containerGroupsClient := storage.NewFileSharesClient(subscriptionID) authorizer, mgmtURL, err := getClientSetupData()
err := setupClient(&containerGroupsClient.Client)
if err != nil { if err != nil {
return storage.FileSharesClient{}, err return storage.FileSharesClient{}, err
} }
containerGroupsClient.PollingDelay = 5 * time.Second fileSharesClient := storage.NewFileSharesClientWithBaseURI(mgmtURL, subscriptionID)
containerGroupsClient.RetryAttempts = 30 setupClient(&fileSharesClient.Client, authorizer)
containerGroupsClient.RetryDuration = 1 * time.Second fileSharesClient.PollingDelay = 5 * time.Second
return containerGroupsClient, nil fileSharesClient.RetryAttempts = 30
fileSharesClient.RetryDuration = 1 * time.Second
return fileSharesClient, nil
} }
// NewSubscriptionsClient get subscription client // NewSubscriptionsClient get subscription client
func NewSubscriptionsClient() (subscription.SubscriptionsClient, error) { func NewSubscriptionsClient() (subscription.SubscriptionsClient, error) {
subc := subscription.NewSubscriptionsClient() authorizer, mgmtURL, err := getClientSetupData()
err := setupClient(&subc.Client)
if err != nil { if err != nil {
return subscription.SubscriptionsClient{}, errors.Wrap(errdefs.ErrLoginRequired, err.Error()) return subscription.SubscriptionsClient{}, errors.Wrap(errdefs.ErrLoginRequired, err.Error())
} }
subc := subscription.NewSubscriptionsClientWithBaseURI(mgmtURL)
setupClient(&subc.Client, authorizer)
return subc, nil return subc, nil
} }
// NewGroupsClient get client to manipulate groups // NewGroupsClient get client to manipulate groups
func NewGroupsClient(subscriptionID string) (resources.GroupsClient, error) { func NewGroupsClient(subscriptionID string) (resources.GroupsClient, error) {
groupsClient := resources.NewGroupsClient(subscriptionID) authorizer, mgmtURL, err := getClientSetupData()
err := setupClient(&groupsClient.Client)
if err != nil { if err != nil {
return resources.GroupsClient{}, err return resources.GroupsClient{}, err
} }
groupsClient := resources.NewGroupsClientWithBaseURI(mgmtURL, subscriptionID)
setupClient(&groupsClient.Client, authorizer)
return groupsClient, nil return groupsClient, nil
} }
// NewContainerClient get client to manipulate containers // NewContainerClient get client to manipulate containers
func NewContainerClient(subscriptionID string) (containerinstance.ContainersClient, error) { func NewContainerClient(subscriptionID string) (containerinstance.ContainersClient, error) {
containerClient := containerinstance.NewContainersClient(subscriptionID) authorizer, mgmtURL, err := getClientSetupData()
err := setupClient(&containerClient.Client)
if err != nil { if err != nil {
return containerinstance.ContainersClient{}, err return containerinstance.ContainersClient{}, err
} }
containerClient := containerinstance.NewContainersClientWithBaseURI(mgmtURL, subscriptionID)
setupClient(&containerClient.Client, authorizer)
return containerClient, nil return containerClient, nil
} }
func getClientSetupData() (autorest.Authorizer, string, error) {
return getClientSetupDataImpl(GetTokenStorePath())
}
func getClientSetupDataImpl(tokenStorePath string) (autorest.Authorizer, string, error) {
als, err := newAzureLoginServiceFromPath(tokenStorePath, azureAPIHelper{}, CloudEnvironments)
if err != nil {
return nil, "", err
}
oauthToken, _, err := als.GetValidToken()
if err != nil {
return nil, "", errors.Wrap(err, "not logged in to azure, you need to run \"docker login azure\" first")
}
ce, err := als.GetCloudEnvironment()
if err != nil {
return nil, "", err
}
token := adal.Token{
AccessToken: oauthToken.AccessToken,
Type: oauthToken.TokenType,
ExpiresIn: json.Number(strconv.Itoa(int(time.Until(oauthToken.Expiry).Seconds()))),
ExpiresOn: json.Number(strconv.Itoa(int(oauthToken.Expiry.Sub(date.UnixEpoch()).Seconds()))),
RefreshToken: "",
Resource: "",
}
return autorest.NewBearerAuthorizer(&token), ce.ResourceManagerURL, nil
}

36
aci/login/client_test.go Normal file
View File

@ -0,0 +1,36 @@
/*
Copyright 2020 Docker Compose CLI authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package login
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
"gotest.tools/v3/assert"
)
func TestClearErrorMessageIfNotAlreadyLoggedIn(t *testing.T) {
dir, err := ioutil.TempDir("", "test_store")
assert.NilError(t, err)
t.Cleanup(func() {
_ = os.RemoveAll(dir)
})
_, _, err = getClientSetupDataImpl(filepath.Join(dir, tokenStoreFilename))
assert.ErrorContains(t, err, "not logged in to azure, you need to run \"docker login azure\" first")
}

View File

@ -0,0 +1,274 @@
/*
Copyright 2020 Docker Compose CLI authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package login
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"strings"
"github.com/pkg/errors"
)
const (
// AzurePublicCloudName is the moniker of the Azure public cloud
AzurePublicCloudName = "AzureCloud"
// AcrSuffixKey is the well-known name of the DNS suffix for Azure Container Registries
AcrSuffixKey = "acrLoginServer"
// CloudMetadataURLVar is the name of the environment variable that (if defined), points to a URL that should be used by Docker CLI to retrieve cloud metadata
CloudMetadataURLVar = "ARM_CLOUD_METADATA_URL"
// DefaultCloudMetadataURL is the URL of the cloud metadata service maintained by Azure public cloud
DefaultCloudMetadataURL = "https://management.azure.com/metadata/endpoints?api-version=2019-05-01"
)
// CloudEnvironmentService exposed metadata about Azure cloud environments
type CloudEnvironmentService interface {
Get(name string) (CloudEnvironment, error)
}
type cloudEnvironmentService struct {
cloudEnvironments map[string]CloudEnvironment
cloudMetadataURL string
// True if we have queried the cloud metadata endpoint already.
// We do it only once per CLI invocation.
metadataQueried bool
}
var (
// CloudEnvironments is the default instance of the CloudEnvironmentService
CloudEnvironments CloudEnvironmentService
)
func init() {
CloudEnvironments = newCloudEnvironmentService()
}
// CloudEnvironmentAuthentication data for logging into, and obtaining tokens for, Azure sovereign clouds
type CloudEnvironmentAuthentication struct {
LoginEndpoint string `json:"loginEndpoint"`
Audiences []string `json:"audiences"`
Tenant string `json:"tenant"`
}
// CloudEnvironment describes Azure sovereign cloud instance (e.g. Azure public, Azure US government, Azure China etc.)
type CloudEnvironment struct {
Name string `json:"name"`
Authentication CloudEnvironmentAuthentication `json:"authentication"`
ResourceManagerURL string `json:"resourceManager"`
Suffixes map[string]string `json:"suffixes"`
}
func newCloudEnvironmentService() *cloudEnvironmentService {
retval := cloudEnvironmentService{
metadataQueried: false,
}
retval.resetCloudMetadata()
return &retval
}
func (ces *cloudEnvironmentService) Get(name string) (CloudEnvironment, error) {
if ce, present := ces.cloudEnvironments[name]; present {
return ce, nil
}
if !ces.metadataQueried {
ces.metadataQueried = true
if ces.cloudMetadataURL == "" {
ces.cloudMetadataURL = os.Getenv(CloudMetadataURLVar)
if _, err := url.ParseRequestURI(ces.cloudMetadataURL); err != nil {
ces.cloudMetadataURL = DefaultCloudMetadataURL
}
}
res, err := http.Get(ces.cloudMetadataURL)
if err != nil {
return CloudEnvironment{}, fmt.Errorf("Cloud metadata retrieval from '%s' failed: %w", ces.cloudMetadataURL, err)
}
if res.StatusCode != 200 {
return CloudEnvironment{}, errors.Errorf("Cloud metadata retrieval from '%s' failed: server response was '%s'", ces.cloudMetadataURL, res.Status)
}
bytes, err := ioutil.ReadAll(res.Body)
if err != nil {
return CloudEnvironment{}, fmt.Errorf("Cloud metadata retrieval from '%s' failed: %w", ces.cloudMetadataURL, err)
}
if err = ces.applyCloudMetadata(bytes); err != nil {
return CloudEnvironment{}, fmt.Errorf("Cloud metadata retrieval from '%s' failed: %w", ces.cloudMetadataURL, err)
}
}
if ce, present := ces.cloudEnvironments[name]; present {
return ce, nil
}
return CloudEnvironment{}, errors.Errorf("Cloud environment '%s' was not found", name)
}
func (ces *cloudEnvironmentService) applyCloudMetadata(jsonBytes []byte) error {
input := []CloudEnvironment{}
if err := json.Unmarshal(jsonBytes, &input); err != nil {
return err
}
newEnvironments := make(map[string]CloudEnvironment, len(input))
// If _any_ of the submitted data is invalid, we bail out.
for _, e := range input {
if len(e.Name) == 0 {
return errors.New("Azure cloud environment metadata is invalid (an environment with no name has been encountered)")
}
e.normalizeURLs()
if _, err := url.ParseRequestURI(e.Authentication.LoginEndpoint); err != nil {
return errors.Errorf("Metadata of cloud environment '%s' has invalid login endpoint URL: %s", e.Name, e.Authentication.LoginEndpoint)
}
if _, err := url.ParseRequestURI(e.ResourceManagerURL); err != nil {
return errors.Errorf("Metadata of cloud environment '%s' has invalid resource manager URL: %s", e.Name, e.ResourceManagerURL)
}
if len(e.Authentication.Audiences) == 0 {
return errors.Errorf("Metadata of cloud environment '%s' is invalid (no authentication audiences)", e.Name)
}
newEnvironments[e.Name] = e
}
for name, e := range newEnvironments {
ces.cloudEnvironments[name] = e
}
return nil
}
func (ces *cloudEnvironmentService) resetCloudMetadata() {
azurePublicCloud := CloudEnvironment{
Name: AzurePublicCloudName,
Authentication: CloudEnvironmentAuthentication{
LoginEndpoint: "https://login.microsoftonline.com",
Audiences: []string{
"https://management.core.windows.net",
"https://management.azure.com",
},
Tenant: "common",
},
ResourceManagerURL: "https://management.azure.com",
Suffixes: map[string]string{
AcrSuffixKey: "azurecr.io",
},
}
azureChinaCloud := CloudEnvironment{
Name: "AzureChinaCloud",
Authentication: CloudEnvironmentAuthentication{
LoginEndpoint: "https://login.chinacloudapi.cn",
Audiences: []string{
"https://management.core.chinacloudapi.cn",
"https://management.chinacloudapi.cn",
},
Tenant: "common",
},
ResourceManagerURL: "https://management.chinacloudapi.cn",
Suffixes: map[string]string{
AcrSuffixKey: "azurecr.cn",
},
}
azureUSGovernment := CloudEnvironment{
Name: "AzureUSGovernment",
Authentication: CloudEnvironmentAuthentication{
LoginEndpoint: "https://login.microsoftonline.us",
Audiences: []string{
"https://management.core.usgovcloudapi.net",
"https://management.usgovcloudapi.net",
},
Tenant: "common",
},
ResourceManagerURL: "https://management.usgovcloudapi.net",
Suffixes: map[string]string{
AcrSuffixKey: "azurecr.us",
},
}
azureGermanCloud := CloudEnvironment{
Name: "AzureGermanCloud",
Authentication: CloudEnvironmentAuthentication{
LoginEndpoint: "https://login.microsoftonline.de",
Audiences: []string{
"https://management.core.cloudapi.de",
"https://management.microsoftazure.de",
},
Tenant: "common",
},
ResourceManagerURL: "https://management.microsoftazure.de",
// There is no separate container registry suffix for German cloud
Suffixes: map[string]string{},
}
ces.cloudEnvironments = map[string]CloudEnvironment{
azurePublicCloud.Name: azurePublicCloud,
azureChinaCloud.Name: azureChinaCloud,
azureUSGovernment.Name: azureUSGovernment,
azureGermanCloud.Name: azureGermanCloud,
}
}
// GetTenantQueryURL returns an URL that can be used to fetch the list of Azure Active Directory tenants from a given cloud environment
func (ce *CloudEnvironment) GetTenantQueryURL() string {
tenantURL := ce.ResourceManagerURL + "/tenants?api-version=2019-11-01"
return tenantURL
}
// GetTokenScope returns a token scope that fits Docker CLI Azure management API usage
func (ce *CloudEnvironment) GetTokenScope() string {
scope := "offline_access " + ce.ResourceManagerURL + "/.default"
return scope
}
// GetAuthorizeRequestFormat returns a string format that can be used to construct authorization code request in an OAuth2 flow.
// The URL uses login endpoint appropriate for given cloud environment.
func (ce *CloudEnvironment) GetAuthorizeRequestFormat() string {
authorizeFormat := ce.Authentication.LoginEndpoint + "/organizations/oauth2/v2.0/authorize?response_type=code&client_id=%s&redirect_uri=%s&state=%s&prompt=select_account&response_mode=query&scope=%s"
return authorizeFormat
}
// GetTokenRequestFormat returns a string format that can be used to construct a security token request against Azure Active Directory
func (ce *CloudEnvironment) GetTokenRequestFormat() string {
tokenEndpoint := ce.Authentication.LoginEndpoint + "/%s/oauth2/v2.0/token"
return tokenEndpoint
}
func (ce *CloudEnvironment) normalizeURLs() {
ce.ResourceManagerURL = removeTrailingSlash(ce.ResourceManagerURL)
ce.Authentication.LoginEndpoint = removeTrailingSlash(ce.Authentication.LoginEndpoint)
for i, s := range ce.Authentication.Audiences {
ce.Authentication.Audiences[i] = removeTrailingSlash(s)
}
}
func removeTrailingSlash(s string) string {
return strings.TrimSuffix(s, "/")
}

View File

@ -0,0 +1,187 @@
/*
Copyright 2020 Docker Compose CLI authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package login
import (
"testing"
"gotest.tools/v3/assert"
)
func TestNormalizeCloudEnvironmentURLs(t *testing.T) {
ce := CloudEnvironment{
Name: "SecretCloud",
Authentication: CloudEnvironmentAuthentication{
LoginEndpoint: "https://login.here.com/",
Audiences: []string{
"https://audience1",
"https://audience2/",
},
Tenant: "common",
},
ResourceManagerURL: "invalid URL",
}
ce.normalizeURLs()
assert.Equal(t, ce.Authentication.LoginEndpoint, "https://login.here.com")
assert.Equal(t, ce.Authentication.Audiences[0], "https://audience1")
assert.Equal(t, ce.Authentication.Audiences[1], "https://audience2")
}
func TestApplyInvalidCloudMetadataJSON(t *testing.T) {
ce := newCloudEnvironmentService()
bb := []byte(`This isn't really valid JSON`)
err := ce.applyCloudMetadata(bb)
assert.Assert(t, err != nil, "Cloud metadata was invalid, so the application should have failed")
ensureWellKnownCloudMetadata(t, ce)
}
func TestApplyInvalidCloudMetatada(t *testing.T) {
ce := newCloudEnvironmentService()
// No name (moniker) for the cloud
bb := []byte(`
[{
"authentication": {
"loginEndpoint": "https://login.docker.com/",
"audiences": [
"https://management.docker.com/",
"https://management.cli.docker.com/"
],
"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
},
"suffixes": {
"acrLoginServer": "azurecr.docker.io"
},
"resourceManager": "https://management.docker.com/"
}]`)
err := ce.applyCloudMetadata(bb)
assert.ErrorContains(t, err, "no name")
ensureWellKnownCloudMetadata(t, ce)
// Invalid resource manager URL
bb = []byte(`
[{
"authentication": {
"loginEndpoint": "https://login.docker.com/",
"audiences": [
"https://management.docker.com/",
"https://management.cli.docker.com/"
],
"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
},
"name": "DockerAzureCloud",
"suffixes": {
"acrLoginServer": "azurecr.docker.io"
},
"resourceManager": "123"
}]`)
err = ce.applyCloudMetadata(bb)
assert.ErrorContains(t, err, "invalid resource manager URL")
ensureWellKnownCloudMetadata(t, ce)
// Invalid login endpoint
bb = []byte(`
[{
"authentication": {
"loginEndpoint": "456",
"audiences": [
"https://management.docker.com/",
"https://management.cli.docker.com/"
],
"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
},
"name": "DockerAzureCloud",
"suffixes": {
"acrLoginServer": "azurecr.docker.io"
},
"resourceManager": "https://management.docker.com/"
}]`)
err = ce.applyCloudMetadata(bb)
assert.ErrorContains(t, err, "invalid login endpoint")
ensureWellKnownCloudMetadata(t, ce)
// No audiences
bb = []byte(`
[{
"authentication": {
"loginEndpoint": "https://login.docker.com/",
"audiences": [ ],
"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
},
"name": "DockerAzureCloud",
"suffixes": {
"acrLoginServer": "azurecr.docker.io"
},
"resourceManager": "https://management.docker.com/"
}]`)
err = ce.applyCloudMetadata(bb)
assert.ErrorContains(t, err, "no authentication audiences")
ensureWellKnownCloudMetadata(t, ce)
}
func TestApplyCloudMetadata(t *testing.T) {
ce := newCloudEnvironmentService()
bb := []byte(`
[{
"authentication": {
"loginEndpoint": "https://login.docker.com/",
"audiences": [
"https://management.docker.com/",
"https://management.cli.docker.com/"
],
"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
},
"name": "DockerAzureCloud",
"suffixes": {
"acrLoginServer": "azurecr.docker.io"
},
"resourceManager": "https://management.docker.com/"
}]`)
err := ce.applyCloudMetadata(bb)
assert.NilError(t, err)
env, err := ce.Get("DockerAzureCloud")
assert.NilError(t, err)
assert.Equal(t, env.Authentication.LoginEndpoint, "https://login.docker.com")
ensureWellKnownCloudMetadata(t, ce)
}
func TestDefaultCloudMetadataPresent(t *testing.T) {
ensureWellKnownCloudMetadata(t, CloudEnvironments)
}
func ensureWellKnownCloudMetadata(t *testing.T, ce CloudEnvironmentService) {
// Make sure well-known public cloud information is still available
_, err := ce.Get(AzurePublicCloudName)
assert.NilError(t, err)
_, err = ce.Get("AzureChinaCloud")
assert.NilError(t, err)
_, err = ce.Get("AzureUSGovernment")
assert.NilError(t, err)
}

View File

@ -39,17 +39,17 @@ var (
) )
type apiHelper interface { type apiHelper interface {
queryToken(data url.Values, tenantID string) (azureToken, error) queryToken(ce CloudEnvironment, data url.Values, tenantID string) (azureToken, error)
openAzureLoginPage(redirectURL string) error openAzureLoginPage(redirectURL string, ce CloudEnvironment) error
queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error)
getDeviceCodeFlowToken() (adal.Token, error) getDeviceCodeFlowToken(ce CloudEnvironment) (adal.Token, error)
} }
type azureAPIHelper struct{} type azureAPIHelper struct{}
func (helper azureAPIHelper) getDeviceCodeFlowToken() (adal.Token, error) { func (helper azureAPIHelper) getDeviceCodeFlowToken(ce CloudEnvironment) (adal.Token, error) {
deviceconfig := auth.NewDeviceFlowConfig(clientID, "common") deviceconfig := auth.NewDeviceFlowConfig(clientID, "common")
deviceconfig.Resource = azureManagementURL deviceconfig.Resource = ce.ResourceManagerURL
spToken, err := deviceconfig.ServicePrincipalToken() spToken, err := deviceconfig.ServicePrincipalToken()
if err != nil { if err != nil {
return adal.Token{}, err return adal.Token{}, err
@ -57,9 +57,9 @@ func (helper azureAPIHelper) getDeviceCodeFlowToken() (adal.Token, error) {
return spToken.Token(), err return spToken.Token(), err
} }
func (helper azureAPIHelper) openAzureLoginPage(redirectURL string) error { func (helper azureAPIHelper) openAzureLoginPage(redirectURL string, ce CloudEnvironment) error {
state := randomString("", 10) state := randomString("", 10)
authURL := fmt.Sprintf(authorizeFormat, clientID, redirectURL, state, scopes) authURL := fmt.Sprintf(ce.GetAuthorizeRequestFormat(), clientID, redirectURL, state, ce.GetTokenScope())
return openbrowser(authURL) return openbrowser(authURL)
} }
@ -81,8 +81,8 @@ func (helper azureAPIHelper) queryAPIWithHeader(ctx context.Context, authorizati
return bits, res.StatusCode, nil return bits, res.StatusCode, nil
} }
func (helper azureAPIHelper) queryToken(data url.Values, tenantID string) (azureToken, error) { func (helper azureAPIHelper) queryToken(ce CloudEnvironment, data url.Values, tenantID string) (azureToken, error) {
res, err := http.Post(fmt.Sprintf(tokenEndpoint, tenantID), "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) res, err := http.Post(fmt.Sprintf(ce.GetTokenRequestFormat(), tenantID), "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
if err != nil { if err != nil {
return azureToken{}, err return azureToken{}, err
} }

View File

@ -23,13 +23,10 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"strconv"
"time" "time"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal" "github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure/auth" "github.com/Azure/go-autorest/autorest/azure/auth"
"github.com/Azure/go-autorest/autorest/date"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -38,18 +35,6 @@ import (
//go login process, derived from code sample provided by MS at https://github.com/devigned/go-az-cli-stuff //go login process, derived from code sample provided by MS at https://github.com/devigned/go-az-cli-stuff
const ( const (
// AcrRegistrySuffix suffix for ACR registry images
AcrRegistrySuffix = ".azurecr.io"
activeDirectoryURL = "https://login.microsoftonline.com"
azureManagementURL = "https://management.core.windows.net/"
azureResouceManagementURL = "https://management.azure.com/"
authorizeFormat = activeDirectoryURL + "/organizations/oauth2/v2.0/authorize?response_type=code&client_id=%s&redirect_uri=%s&state=%s&prompt=select_account&response_mode=query&scope=%s"
tokenEndpoint = activeDirectoryURL + "/%s/oauth2/v2.0/token"
getTenantURL = azureResouceManagementURL + "tenants?api-version=2019-11-01"
// scopes for a multi-tenant app works for openid, email, other common scopes, but fails when trying to add a token
// v1 scope like "https://management.azure.com/.default" for ARM access
scopes = "offline_access " + azureResouceManagementURL + ".default"
clientID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46" // Azure CLI client id clientID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46" // Azure CLI client id
) )
@ -73,39 +58,41 @@ type (
) )
// AzureLoginService Service to log into azure and get authentifier for azure APIs // AzureLoginService Service to log into azure and get authentifier for azure APIs
type AzureLoginService struct { type AzureLoginService interface {
tokenStore tokenStore Login(ctx context.Context, requestedTenantID string, cloudEnvironment string) error
apiHelper apiHelper LoginServicePrincipal(clientID string, clientSecret string, tenantID string, cloudEnvironment string) error
}
// AzureLoginServiceAPI interface for Azure login service
type AzureLoginServiceAPI interface {
LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error
Login(ctx context.Context, requestedTenantID string) error
Logout(ctx context.Context) error Logout(ctx context.Context) error
GetCloudEnvironment() (CloudEnvironment, error)
GetValidToken() (oauth2.Token, string, error)
}
type azureLoginService struct {
tokenStore tokenStore
apiHelper apiHelper
cloudEnvironmentSvc CloudEnvironmentService
} }
const tokenStoreFilename = "dockerAccessToken.json" const tokenStoreFilename = "dockerAccessToken.json"
// NewAzureLoginService creates a NewAzureLoginService // NewAzureLoginService creates a NewAzureLoginService
func NewAzureLoginService() (*AzureLoginService, error) { func NewAzureLoginService() (AzureLoginService, error) {
return newAzureLoginServiceFromPath(GetTokenStorePath(), azureAPIHelper{}) return newAzureLoginServiceFromPath(GetTokenStorePath(), azureAPIHelper{}, CloudEnvironments)
} }
func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper) (*AzureLoginService, error) { func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper, ces CloudEnvironmentService) (*azureLoginService, error) {
store, err := newTokenStore(tokenStorePath) store, err := newTokenStore(tokenStorePath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &AzureLoginService{ return &azureLoginService{
tokenStore: store, tokenStore: store,
apiHelper: helper, apiHelper: helper,
cloudEnvironmentSvc: ces,
}, nil }, nil
} }
// LoginServicePrincipal login with clientId / clientSecret from a service principal. // LoginServicePrincipal login with clientId / clientSecret from a service principal.
// The resulting token does not include a refresh token // The resulting token does not include a refresh token
func (login *AzureLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error { func (login *azureLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string, cloudEnvironment string) error {
// Tried with auth2.NewUsernamePasswordConfig() but could not make this work with username / password, setting this for CI with clientID / clientSecret // Tried with auth2.NewUsernamePasswordConfig() but could not make this work with username / password, setting this for CI with clientID / clientSecret
creds := auth.NewClientCredentialsConfig(clientID, clientSecret, tenantID) creds := auth.NewClientCredentialsConfig(clientID, clientSecret, tenantID)
@ -121,7 +108,7 @@ func (login *AzureLoginService) LoginServicePrincipal(clientID string, clientSec
if err != nil { if err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "could not read service principal token expiry: %s", err) return errors.Wrapf(errdefs.ErrLoginFailed, "could not read service principal token expiry: %s", err)
} }
loginInfo := TokenInfo{TenantID: tenantID, Token: token} loginInfo := TokenInfo{TenantID: tenantID, Token: token, CloudEnvironment: cloudEnvironment}
if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil { if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err) return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err)
@ -130,7 +117,7 @@ func (login *AzureLoginService) LoginServicePrincipal(clientID string, clientSec
} }
// Logout remove azure token data // Logout remove azure token data
func (login *AzureLoginService) Logout(ctx context.Context) error { func (login *azureLoginService) Logout(ctx context.Context) error {
err := login.tokenStore.removeData() err := login.tokenStore.removeData()
if os.IsNotExist(err) { if os.IsNotExist(err) {
return errors.New("No Azure login data to be removed") return errors.New("No Azure login data to be removed")
@ -138,8 +125,14 @@ func (login *AzureLoginService) Logout(ctx context.Context) error {
return err return err
} }
func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, accessToken string, refreshToken string, requestedTenantID string) error { func (login *azureLoginService) getTenantAndValidateLogin(
bits, statusCode, err := login.apiHelper.queryAPIWithHeader(ctx, getTenantURL, fmt.Sprintf("Bearer %s", accessToken)) ctx context.Context,
accessToken string,
refreshToken string,
requestedTenantID string,
ce CloudEnvironment,
) error {
bits, statusCode, err := login.apiHelper.queryAPIWithHeader(ctx, ce.GetTenantQueryURL(), fmt.Sprintf("Bearer %s", accessToken))
if err != nil { if err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err) return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err)
} }
@ -155,11 +148,11 @@ func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, a
if err != nil { if err != nil {
return errors.Wrap(errdefs.ErrLoginFailed, err.Error()) return errors.Wrap(errdefs.ErrLoginFailed, err.Error())
} }
tToken, err := login.refreshToken(refreshToken, tenantID) tToken, err := login.refreshToken(refreshToken, tenantID, ce)
if err != nil { if err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "unable to refresh token: %s", err) return errors.Wrapf(errdefs.ErrLoginFailed, "unable to refresh token: %s", err)
} }
loginInfo := TokenInfo{TenantID: tenantID, Token: tToken} loginInfo := TokenInfo{TenantID: tenantID, Token: tToken, CloudEnvironment: ce.Name}
if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil { if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err) return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err)
@ -168,7 +161,12 @@ func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, a
} }
// Login performs an Azure login through a web browser // Login performs an Azure login through a web browser
func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID string) error { func (login *azureLoginService) Login(ctx context.Context, requestedTenantID string, cloudEnvironment string) error {
ce, err := login.cloudEnvironmentSvc.Get(cloudEnvironment)
if err != nil {
return err
}
queryCh := make(chan localResponse, 1) queryCh := make(chan localResponse, 1)
s, err := NewLocalServer(queryCh) s, err := NewLocalServer(queryCh)
if err != nil { if err != nil {
@ -183,8 +181,8 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
} }
deviceCodeFlowCh := make(chan deviceCodeFlowResponse, 1) deviceCodeFlowCh := make(chan deviceCodeFlowResponse, 1)
if err = login.apiHelper.openAzureLoginPage(redirectURL); err != nil { if err = login.apiHelper.openAzureLoginPage(redirectURL, ce); err != nil {
login.startDeviceCodeFlow(deviceCodeFlowCh) login.startDeviceCodeFlow(deviceCodeFlowCh, ce)
} }
select { select {
@ -195,7 +193,7 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err) return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err)
} }
token := dcft.token token := dcft.token
return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID) return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID, ce)
case q := <-queryCh: case q := <-queryCh:
if q.err != nil { if q.err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "unhandled local login server error: %s", err) return errors.Wrapf(errdefs.ErrLoginFailed, "unhandled local login server error: %s", err)
@ -208,14 +206,14 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
"grant_type": []string{"authorization_code"}, "grant_type": []string{"authorization_code"},
"client_id": []string{clientID}, "client_id": []string{clientID},
"code": code, "code": code,
"scope": []string{scopes}, "scope": []string{ce.GetTokenScope()},
"redirect_uri": []string{redirectURL}, "redirect_uri": []string{redirectURL},
} }
token, err := login.apiHelper.queryToken(data, "organizations") token, err := login.apiHelper.queryToken(ce, data, "organizations")
if err != nil { if err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err) return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err)
} }
return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID) return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID, ce)
} }
} }
@ -224,10 +222,10 @@ type deviceCodeFlowResponse struct {
err error err error
} }
func (login *AzureLoginService) startDeviceCodeFlow(deviceCodeFlowCh chan deviceCodeFlowResponse) { func (login *azureLoginService) startDeviceCodeFlow(deviceCodeFlowCh chan deviceCodeFlowResponse, ce CloudEnvironment) {
fmt.Println("Could not automatically open a browser, falling back to Azure device code flow authentication") fmt.Println("Could not automatically open a browser, falling back to Azure device code flow authentication")
go func() { go func() {
token, err := login.apiHelper.getDeviceCodeFlowToken() token, err := login.apiHelper.getDeviceCodeFlowToken(ce)
if err != nil { if err != nil {
deviceCodeFlowCh <- deviceCodeFlowResponse{err: err} deviceCodeFlowCh <- deviceCodeFlowResponse{err: err}
} }
@ -276,72 +274,58 @@ func spToOAuthToken(token adal.Token) (oauth2.Token, error) {
return oauthToken, nil return oauthToken, nil
} }
// NewAuthorizerFromLogin creates an authorizer based on login access token // GetValidToken returns an access token and associated tenant ID.
func NewAuthorizerFromLogin() (autorest.Authorizer, error) { // Will refresh the token as necessary.
return newAuthorizerFromLoginStorePath(GetTokenStorePath()) func (login *azureLoginService) GetValidToken() (oauth2.Token, string, error) {
}
func newAuthorizerFromLoginStorePath(storeTokenPath string) (autorest.Authorizer, error) {
login, err := newAzureLoginServiceFromPath(storeTokenPath, azureAPIHelper{})
if err != nil {
return nil, err
}
oauthToken, err := login.GetValidToken()
if err != nil {
return nil, errors.Wrap(err, "not logged in to azure, you need to run \"docker login azure\" first")
}
token := adal.Token{
AccessToken: oauthToken.AccessToken,
Type: oauthToken.TokenType,
ExpiresIn: json.Number(strconv.Itoa(int(time.Until(oauthToken.Expiry).Seconds()))),
ExpiresOn: json.Number(strconv.Itoa(int(oauthToken.Expiry.Sub(date.UnixEpoch()).Seconds()))),
RefreshToken: "",
Resource: "",
}
return autorest.NewBearerAuthorizer(&token), nil
}
// GetTenantID returns tenantID for current login
func (login AzureLoginService) GetTenantID() (string, error) {
loginInfo, err := login.tokenStore.readToken() loginInfo, err := login.tokenStore.readToken()
if err != nil { if err != nil {
return "", err return oauth2.Token{}, "", err
}
return loginInfo.TenantID, err
}
// GetValidToken returns an access token. Refresh token if needed
func (login *AzureLoginService) GetValidToken() (oauth2.Token, error) {
loginInfo, err := login.tokenStore.readToken()
if err != nil {
return oauth2.Token{}, err
} }
token := loginInfo.Token token := loginInfo.Token
if token.Valid() {
return token, nil
}
tenantID := loginInfo.TenantID tenantID := loginInfo.TenantID
token, err = login.refreshToken(token.RefreshToken, tenantID) if token.Valid() {
if err != nil { return token, tenantID, nil
return oauth2.Token{}, errors.Wrap(err, "access token request failed. Maybe you need to login to azure again.")
} }
err = login.tokenStore.writeLoginInfo(TokenInfo{TenantID: tenantID, Token: token})
ce, err := login.cloudEnvironmentSvc.Get(loginInfo.CloudEnvironment)
if err != nil { if err != nil {
return oauth2.Token{}, err return oauth2.Token{}, "", errors.Wrap(err, "access token request failed--cloud environment could not be determined.")
} }
return token, nil
token, err = login.refreshToken(token.RefreshToken, tenantID, ce)
if err != nil {
return oauth2.Token{}, "", errors.Wrap(err, "access token request failed. Maybe you need to login to Azure again.")
}
err = login.tokenStore.writeLoginInfo(TokenInfo{TenantID: tenantID, Token: token, CloudEnvironment: ce.Name})
if err != nil {
return oauth2.Token{}, "", err
}
return token, tenantID, nil
} }
func (login *AzureLoginService) refreshToken(currentRefreshToken string, tenantID string) (oauth2.Token, error) { // GeCloudEnvironment returns the cloud environment associated with the current authentication token (if we have one)
func (login *azureLoginService) GetCloudEnvironment() (CloudEnvironment, error) {
tokenInfo, err := login.tokenStore.readToken()
if err != nil {
return CloudEnvironment{}, err
}
cloudEnvironment, err := login.cloudEnvironmentSvc.Get(tokenInfo.CloudEnvironment)
if err != nil {
return CloudEnvironment{}, err
}
return cloudEnvironment, nil
}
func (login *azureLoginService) refreshToken(currentRefreshToken string, tenantID string, ce CloudEnvironment) (oauth2.Token, error) {
data := url.Values{ data := url.Values{
"grant_type": []string{"refresh_token"}, "grant_type": []string{"refresh_token"},
"client_id": []string{clientID}, "client_id": []string{clientID},
"scope": []string{scopes}, "scope": []string{ce.GetTokenScope()},
"refresh_token": []string{currentRefreshToken}, "refresh_token": []string{currentRefreshToken},
} }
token, err := login.apiHelper.queryToken(data, tenantID) token, err := login.apiHelper.queryToken(ce, data, tenantID)
if err != nil { if err != nil {
return oauth2.Token{}, err return oauth2.Token{}, err
} }

View File

@ -21,10 +21,12 @@ import (
"errors" "errors"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"sync/atomic"
"testing" "testing"
"time" "time"
@ -36,7 +38,7 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
func testLoginService(t *testing.T, m *MockAzureHelper) (*AzureLoginService, error) { func testLoginService(t *testing.T, apiHelperMock *MockAzureHelper, cloudEnvironmentSvc CloudEnvironmentService) (*azureLoginService, error) {
dir, err := ioutil.TempDir("", "test_store") dir, err := ioutil.TempDir("", "test_store")
if err != nil { if err != nil {
return nil, err return nil, err
@ -44,20 +46,45 @@ func testLoginService(t *testing.T, m *MockAzureHelper) (*AzureLoginService, err
t.Cleanup(func() { t.Cleanup(func() {
_ = os.RemoveAll(dir) _ = os.RemoveAll(dir)
}) })
return newAzureLoginServiceFromPath(filepath.Join(dir, tokenStoreFilename), m)
ces := CloudEnvironments
if cloudEnvironmentSvc != nil {
ces = cloudEnvironmentSvc
}
return newAzureLoginServiceFromPath(filepath.Join(dir, tokenStoreFilename), apiHelperMock, ces)
} }
func TestRefreshInValidToken(t *testing.T) { func TestRefreshInValidToken(t *testing.T) {
data := refreshTokenData("refreshToken") data := url.Values{
m := &MockAzureHelper{} "grant_type": []string{"refresh_token"},
m.On("queryToken", data, "123456").Return(azureToken{ "client_id": []string{clientID},
"scope": []string{"offline_access https://management.docker.com/.default"},
"refresh_token": []string{"refreshToken"},
}
helperMock := &MockAzureHelper{}
helperMock.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "123456").Return(azureToken{
RefreshToken: "newRefreshToken", RefreshToken: "newRefreshToken",
AccessToken: "newAccessToken", AccessToken: "newAccessToken",
ExpiresIn: 3600, ExpiresIn: 3600,
Foci: "1", Foci: "1",
}, nil) }, nil)
azureLogin, err := testLoginService(t, m) cloudEnvironmentSvcMock := &MockCloudEnvironmentService{}
cloudEnvironmentSvcMock.On("Get", "AzureDockerCloud").Return(CloudEnvironment{
Name: "AzureDockerCloud",
Authentication: CloudEnvironmentAuthentication{
LoginEndpoint: "https://login.docker.com",
Audiences: []string{
"https://management.docker.com",
"https://management-ext.docker.com",
},
Tenant: "common",
},
ResourceManagerURL: "https://management.docker.com",
Suffixes: map[string]string{},
}, nil)
azureLogin, err := testLoginService(t, helperMock, cloudEnvironmentSvcMock)
assert.NilError(t, err) assert.NilError(t, err)
err = azureLogin.tokenStore.writeLoginInfo(TokenInfo{ err = azureLogin.tokenStore.writeLoginInfo(TokenInfo{
TenantID: "123456", TenantID: "123456",
@ -67,33 +94,29 @@ func TestRefreshInValidToken(t *testing.T) {
Expiry: time.Now().Add(-1 * time.Hour), Expiry: time.Now().Add(-1 * time.Hour),
TokenType: "Bearer", TokenType: "Bearer",
}, },
CloudEnvironment: "AzureDockerCloud",
}) })
assert.NilError(t, err) assert.NilError(t, err)
token, _ := azureLogin.GetValidToken() token, tenantID, err := azureLogin.GetValidToken()
assert.NilError(t, err)
assert.Equal(t, tenantID, "123456")
assert.Equal(t, token.AccessToken, "newAccessToken") assert.Equal(t, token.AccessToken, "newAccessToken")
assert.Assert(t, time.Now().Add(3500*time.Second).Before(token.Expiry)) assert.Assert(t, time.Now().Add(3500*time.Second).Before(token.Expiry))
storedToken, _ := azureLogin.tokenStore.readToken() storedToken, err := azureLogin.tokenStore.readToken()
assert.NilError(t, err)
assert.Equal(t, storedToken.Token.AccessToken, "newAccessToken") assert.Equal(t, storedToken.Token.AccessToken, "newAccessToken")
assert.Equal(t, storedToken.Token.RefreshToken, "newRefreshToken") assert.Equal(t, storedToken.Token.RefreshToken, "newRefreshToken")
assert.Assert(t, time.Now().Add(3500*time.Second).Before(storedToken.Token.Expiry)) assert.Assert(t, time.Now().Add(3500*time.Second).Before(storedToken.Token.Expiry))
}
func TestClearErrorMessageIfNotAlreadyLoggedIn(t *testing.T) { assert.Equal(t, storedToken.CloudEnvironment, "AzureDockerCloud")
dir, err := ioutil.TempDir("", "test_store")
assert.NilError(t, err)
t.Cleanup(func() {
_ = os.RemoveAll(dir)
})
_, err = newAuthorizerFromLoginStorePath(filepath.Join(dir, tokenStoreFilename))
assert.ErrorContains(t, err, "not logged in to azure, you need to run \"docker login azure\" first")
} }
func TestDoesNotRefreshValidToken(t *testing.T) { func TestDoesNotRefreshValidToken(t *testing.T) {
expiryDate := time.Now().Add(1 * time.Hour) expiryDate := time.Now().Add(1 * time.Hour)
azureLogin, err := testLoginService(t, nil) azureLogin, err := testLoginService(t, nil, nil)
assert.NilError(t, err) assert.NilError(t, err)
err = azureLogin.tokenStore.writeLoginInfo(TokenInfo{ err = azureLogin.tokenStore.writeLoginInfo(TokenInfo{
TenantID: "123456", TenantID: "123456",
@ -103,25 +126,55 @@ func TestDoesNotRefreshValidToken(t *testing.T) {
Expiry: expiryDate, Expiry: expiryDate,
TokenType: "Bearer", TokenType: "Bearer",
}, },
CloudEnvironment: AzurePublicCloudName,
}) })
assert.NilError(t, err) assert.NilError(t, err)
token, _ := azureLogin.GetValidToken() token, tenantID, err := azureLogin.GetValidToken()
assert.NilError(t, err)
assert.Equal(t, token.AccessToken, "accessToken") assert.Equal(t, token.AccessToken, "accessToken")
assert.Equal(t, tenantID, "123456")
}
func TestTokenStoreAssumesAzurePublicCloud(t *testing.T) {
expiryDate := time.Now().Add(1 * time.Hour)
azureLogin, err := testLoginService(t, nil, nil)
assert.NilError(t, err)
err = azureLogin.tokenStore.writeLoginInfo(TokenInfo{
TenantID: "123456",
Token: oauth2.Token{
AccessToken: "accessToken",
RefreshToken: "refreshToken",
Expiry: expiryDate,
TokenType: "Bearer",
},
// Simulates upgrade from older version of Docker CLI that did not have cloud environment concept
CloudEnvironment: "",
})
assert.NilError(t, err)
token, tenantID, err := azureLogin.GetValidToken()
assert.NilError(t, err)
assert.Equal(t, tenantID, "123456")
assert.Equal(t, token.AccessToken, "accessToken")
ce, err := azureLogin.GetCloudEnvironment()
assert.NilError(t, err)
assert.Equal(t, ce.Name, AzurePublicCloudName)
} }
func TestInvalidLogin(t *testing.T) { func TestInvalidLogin(t *testing.T) {
m := &MockAzureHelper{} m := &MockAzureHelper{}
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
redirectURL := args.Get(0).(string) redirectURL := args.Get(0).(string)
err := queryKeyValue(redirectURL, "error", "access denied: login failed") err := queryKeyValue(redirectURL, "error", "access denied: login failed")
assert.NilError(t, err) assert.NilError(t, err)
}).Return(nil) }).Return(nil)
azureLogin, err := testLoginService(t, m) azureLogin, err := testLoginService(t, m, nil)
assert.NilError(t, err) assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "") err = azureLogin.Login(context.TODO(), "", AzurePublicCloudName)
assert.Error(t, err, "no login code: login failed") assert.Error(t, err, "no login code: login failed")
} }
@ -129,19 +182,22 @@ func TestValidLogin(t *testing.T) {
var redirectURL string var redirectURL string
ctx := context.TODO() ctx := context.TODO()
m := &MockAzureHelper{} m := &MockAzureHelper{}
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { ce, err := CloudEnvironments.Get(AzurePublicCloudName)
assert.NilError(t, err)
m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
redirectURL = args.Get(0).(string) redirectURL = args.Get(0).(string)
err := queryKeyValue(redirectURL, "code", "123456879") err := queryKeyValue(redirectURL, "code", "123456879")
assert.NilError(t, err) assert.NilError(t, err)
}).Return(nil) }).Return(nil)
m.On("queryToken", mock.MatchedBy(func(data url.Values) bool { m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
return reflect.DeepEqual(data, url.Values{ return reflect.DeepEqual(data, url.Values{
"grant_type": []string{"authorization_code"}, "grant_type": []string{"authorization_code"},
"client_id": []string{clientID}, "client_id": []string{clientID},
"code": []string{"123456879"}, "code": []string{"123456879"},
"scope": []string{scopes}, "scope": []string{ce.GetTokenScope()},
"redirect_uri": []string{redirectURL}, "redirect_uri": []string{redirectURL},
}) })
}), "organizations").Return(azureToken{ }), "organizations").Return(azureToken{
@ -153,18 +209,18 @@ func TestValidLogin(t *testing.T) {
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken") data := refreshTokenData("firstRefreshToken", ce)
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
RefreshToken: "newRefreshToken", RefreshToken: "newRefreshToken",
AccessToken: "newAccessToken", AccessToken: "newAccessToken",
ExpiresIn: 3600, ExpiresIn: 3600,
Foci: "1", Foci: "1",
}, nil) }, nil)
azureLogin, err := testLoginService(t, m) azureLogin, err := testLoginService(t, m, nil)
assert.NilError(t, err) assert.NilError(t, err)
err = azureLogin.Login(ctx, "") err = azureLogin.Login(ctx, "", AzurePublicCloudName)
assert.NilError(t, err) assert.NilError(t, err)
loginToken, err := azureLogin.tokenStore.readToken() loginToken, err := azureLogin.tokenStore.readToken()
@ -174,24 +230,28 @@ func TestValidLogin(t *testing.T) {
assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry)) assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry))
assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038") assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038")
assert.Equal(t, loginToken.Token.Type(), "Bearer") assert.Equal(t, loginToken.Token.Type(), "Bearer")
assert.Equal(t, loginToken.CloudEnvironment, "AzureCloud")
} }
func TestValidLoginRequestedTenant(t *testing.T) { func TestValidLoginRequestedTenant(t *testing.T) {
var redirectURL string var redirectURL string
m := &MockAzureHelper{} m := &MockAzureHelper{}
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { ce, err := CloudEnvironments.Get(AzurePublicCloudName)
assert.NilError(t, err)
m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
redirectURL = args.Get(0).(string) redirectURL = args.Get(0).(string)
err := queryKeyValue(redirectURL, "code", "123456879") err := queryKeyValue(redirectURL, "code", "123456879")
assert.NilError(t, err) assert.NilError(t, err)
}).Return(nil) }).Return(nil)
m.On("queryToken", mock.MatchedBy(func(data url.Values) bool { m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
return reflect.DeepEqual(data, url.Values{ return reflect.DeepEqual(data, url.Values{
"grant_type": []string{"authorization_code"}, "grant_type": []string{"authorization_code"},
"client_id": []string{clientID}, "client_id": []string{clientID},
"code": []string{"123456879"}, "code": []string{"123456879"},
"scope": []string{scopes}, "scope": []string{ce.GetTokenScope()},
"redirect_uri": []string{redirectURL}, "redirect_uri": []string{redirectURL},
}) })
}), "organizations").Return(azureToken{ }), "organizations").Return(azureToken{
@ -205,18 +265,18 @@ func TestValidLoginRequestedTenant(t *testing.T) {
{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` {"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
ctx := context.TODO() ctx := context.TODO()
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken") data := refreshTokenData("firstRefreshToken", ce)
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
RefreshToken: "newRefreshToken", RefreshToken: "newRefreshToken",
AccessToken: "newAccessToken", AccessToken: "newAccessToken",
ExpiresIn: 3600, ExpiresIn: 3600,
Foci: "1", Foci: "1",
}, nil) }, nil)
azureLogin, err := testLoginService(t, m) azureLogin, err := testLoginService(t, m, nil)
assert.NilError(t, err) assert.NilError(t, err)
err = azureLogin.Login(ctx, "12345a7c-c56d-43e8-9549-dd230ce8a038") err = azureLogin.Login(ctx, "12345a7c-c56d-43e8-9549-dd230ce8a038", AzurePublicCloudName)
assert.NilError(t, err) assert.NilError(t, err)
loginToken, err := azureLogin.tokenStore.readToken() loginToken, err := azureLogin.tokenStore.readToken()
@ -226,24 +286,28 @@ func TestValidLoginRequestedTenant(t *testing.T) {
assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry)) assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry))
assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038") assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038")
assert.Equal(t, loginToken.Token.Type(), "Bearer") assert.Equal(t, loginToken.Token.Type(), "Bearer")
assert.Equal(t, loginToken.CloudEnvironment, "AzureCloud")
} }
func TestLoginNoTenant(t *testing.T) { func TestLoginNoTenant(t *testing.T) {
var redirectURL string var redirectURL string
m := &MockAzureHelper{} m := &MockAzureHelper{}
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { ce, err := CloudEnvironments.Get(AzurePublicCloudName)
assert.NilError(t, err)
m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
redirectURL = args.Get(0).(string) redirectURL = args.Get(0).(string)
err := queryKeyValue(redirectURL, "code", "123456879") err := queryKeyValue(redirectURL, "code", "123456879")
assert.NilError(t, err) assert.NilError(t, err)
}).Return(nil) }).Return(nil)
m.On("queryToken", mock.MatchedBy(func(data url.Values) bool { m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
return reflect.DeepEqual(data, url.Values{ return reflect.DeepEqual(data, url.Values{
"grant_type": []string{"authorization_code"}, "grant_type": []string{"authorization_code"},
"client_id": []string{clientID}, "client_id": []string{clientID},
"code": []string{"123456879"}, "code": []string{"123456879"},
"scope": []string{scopes}, "scope": []string{ce.GetTokenScope()},
"redirect_uri": []string{redirectURL}, "redirect_uri": []string{redirectURL},
}) })
}), "organizations").Return(azureToken{ }), "organizations").Return(azureToken{
@ -255,31 +319,34 @@ func TestLoginNoTenant(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
azureLogin, err := testLoginService(t, m) azureLogin, err := testLoginService(t, m, nil)
assert.NilError(t, err) assert.NilError(t, err)
err = azureLogin.Login(ctx, "00000000-c56d-43e8-9549-dd230ce8a038") err = azureLogin.Login(ctx, "00000000-c56d-43e8-9549-dd230ce8a038", AzurePublicCloudName)
assert.Error(t, err, "could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed") assert.Error(t, err, "could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed")
} }
func TestLoginRequestedTenantNotFound(t *testing.T) { func TestLoginRequestedTenantNotFound(t *testing.T) {
var redirectURL string var redirectURL string
m := &MockAzureHelper{} m := &MockAzureHelper{}
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { ce, err := CloudEnvironments.Get(AzurePublicCloudName)
assert.NilError(t, err)
m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
redirectURL = args.Get(0).(string) redirectURL = args.Get(0).(string)
err := queryKeyValue(redirectURL, "code", "123456879") err := queryKeyValue(redirectURL, "code", "123456879")
assert.NilError(t, err) assert.NilError(t, err)
}).Return(nil) }).Return(nil)
m.On("queryToken", mock.MatchedBy(func(data url.Values) bool { m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
return reflect.DeepEqual(data, url.Values{ return reflect.DeepEqual(data, url.Values{
"grant_type": []string{"authorization_code"}, "grant_type": []string{"authorization_code"},
"client_id": []string{clientID}, "client_id": []string{clientID},
"code": []string{"123456879"}, "code": []string{"123456879"},
"scope": []string{scopes}, "scope": []string{ce.GetTokenScope()},
"redirect_uri": []string{redirectURL}, "redirect_uri": []string{redirectURL},
}) })
}), "organizations").Return(azureToken{ }), "organizations").Return(azureToken{
@ -291,31 +358,34 @@ func TestLoginRequestedTenantNotFound(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
authBody := `{"value":[]}` authBody := `{"value":[]}`
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
azureLogin, err := testLoginService(t, m) azureLogin, err := testLoginService(t, m, nil)
assert.NilError(t, err) assert.NilError(t, err)
err = azureLogin.Login(ctx, "") err = azureLogin.Login(ctx, "", AzurePublicCloudName)
assert.Error(t, err, "could not find azure tenant: login failed") assert.Error(t, err, "could not find azure tenant: login failed")
} }
func TestLoginAuthorizationFailed(t *testing.T) { func TestLoginAuthorizationFailed(t *testing.T) {
var redirectURL string var redirectURL string
m := &MockAzureHelper{} m := &MockAzureHelper{}
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { ce, err := CloudEnvironments.Get(AzurePublicCloudName)
assert.NilError(t, err)
m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
redirectURL = args.Get(0).(string) redirectURL = args.Get(0).(string)
err := queryKeyValue(redirectURL, "code", "123456879") err := queryKeyValue(redirectURL, "code", "123456879")
assert.NilError(t, err) assert.NilError(t, err)
}).Return(nil) }).Return(nil)
m.On("queryToken", mock.MatchedBy(func(data url.Values) bool { m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
return reflect.DeepEqual(data, url.Values{ return reflect.DeepEqual(data, url.Values{
"grant_type": []string{"authorization_code"}, "grant_type": []string{"authorization_code"},
"client_id": []string{clientID}, "client_id": []string{clientID},
"code": []string{"123456879"}, "code": []string{"123456879"},
"scope": []string{scopes}, "scope": []string{ce.GetTokenScope()},
"redirect_uri": []string{redirectURL}, "redirect_uri": []string{redirectURL},
}) })
}), "organizations").Return(azureToken{ }), "organizations").Return(azureToken{
@ -328,35 +398,38 @@ func TestLoginAuthorizationFailed(t *testing.T) {
authBody := `[access denied]` authBody := `[access denied]`
ctx := context.TODO() ctx := context.TODO()
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil) m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
azureLogin, err := testLoginService(t, m) azureLogin, err := testLoginService(t, m, nil)
assert.NilError(t, err) assert.NilError(t, err)
err = azureLogin.Login(ctx, "") err = azureLogin.Login(ctx, "", AzurePublicCloudName)
assert.Error(t, err, "unable to login status code 400: [access denied]: login failed") assert.Error(t, err, "unable to login status code 400: [access denied]: login failed")
} }
func TestValidThroughDeviceCodeFlow(t *testing.T) { func TestValidThroughDeviceCodeFlow(t *testing.T) {
m := &MockAzureHelper{} m := &MockAzureHelper{}
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Return(errors.New("Could not open browser")) ce, err := CloudEnvironments.Get(AzurePublicCloudName)
m.On("getDeviceCodeFlowToken").Return(adal.Token{AccessToken: "firstAccessToken", RefreshToken: "firstRefreshToken"}, nil) assert.NilError(t, err)
m.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Return(errors.New("Could not open browser"))
m.On("getDeviceCodeFlowToken", mock.AnythingOfType("CloudEnvironment")).Return(adal.Token{AccessToken: "firstAccessToken", RefreshToken: "firstRefreshToken"}, nil)
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
ctx := context.TODO() ctx := context.TODO()
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) m.On("queryAPIWithHeader", ctx, ce.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken") data := refreshTokenData("firstRefreshToken", ce)
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ m.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
RefreshToken: "newRefreshToken", RefreshToken: "newRefreshToken",
AccessToken: "newAccessToken", AccessToken: "newAccessToken",
ExpiresIn: 3600, ExpiresIn: 3600,
Foci: "1", Foci: "1",
}, nil) }, nil)
azureLogin, err := testLoginService(t, m) azureLogin, err := testLoginService(t, m, nil)
assert.NilError(t, err) assert.NilError(t, err)
err = azureLogin.Login(ctx, "") err = azureLogin.Login(ctx, "", AzurePublicCloudName)
assert.NilError(t, err) assert.NilError(t, err)
loginToken, err := azureLogin.tokenStore.readToken() loginToken, err := azureLogin.tokenStore.readToken()
@ -366,13 +439,110 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) {
assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry)) assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry))
assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038") assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038")
assert.Equal(t, loginToken.Token.Type(), "Bearer") assert.Equal(t, loginToken.Token.Type(), "Bearer")
assert.Equal(t, loginToken.CloudEnvironment, "AzureCloud")
} }
func refreshTokenData(refreshToken string) url.Values { func TestNonstandardCloudEnvironment(t *testing.T) {
dockerCloudMetadata := []byte(`
[{
"authentication": {
"loginEndpoint": "https://login.docker.com/",
"audiences": [
"https://management.docker.com/",
"https://management.cli.docker.com/"
],
"tenant": "F5773994-FE88-482E-9E33-6E799D250416"
},
"name": "AzureDockerCloud",
"suffixes": {
"acrLoginServer": "azurecr.docker.io"
},
"resourceManager": "https://management.docker.com/"
}]`)
var metadataReqCount int32 = 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write(dockerCloudMetadata)
assert.NilError(t, err)
atomic.AddInt32(&metadataReqCount, 1)
}))
defer srv.Close()
cloudMetadataURL, cloudMetadataURLSet := os.LookupEnv(CloudMetadataURLVar)
if cloudMetadataURLSet {
defer func() {
err := os.Setenv(CloudMetadataURLVar, cloudMetadataURL)
assert.NilError(t, err)
}()
}
err := os.Setenv(CloudMetadataURLVar, srv.URL)
assert.NilError(t, err)
ctx := context.TODO()
ces := newCloudEnvironmentService()
ces.cloudMetadataURL = srv.URL
dockerCloudEnv, err := ces.Get("AzureDockerCloud")
assert.NilError(t, err)
helperMock := &MockAzureHelper{}
var redirectURL string
helperMock.On("openAzureLoginPage", mock.AnythingOfType("string"), mock.AnythingOfType("CloudEnvironment")).Run(func(args mock.Arguments) {
redirectURL = args.Get(0).(string)
err := queryKeyValue(redirectURL, "code", "123456879")
assert.NilError(t, err)
}).Return(nil)
helperMock.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), mock.MatchedBy(func(data url.Values) bool {
//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
return reflect.DeepEqual(data, url.Values{
"grant_type": []string{"authorization_code"},
"client_id": []string{clientID},
"code": []string{"123456879"},
"scope": []string{dockerCloudEnv.GetTokenScope()},
"redirect_uri": []string{redirectURL},
})
}), "organizations").Return(azureToken{
RefreshToken: "firstRefreshToken",
AccessToken: "firstAccessToken",
ExpiresIn: 3600,
Foci: "1",
}, nil)
authBody := `{"value":[{"id":"/tenants/F5773994-FE88-482E-9E33-6E799D250416","tenantId":"F5773994-FE88-482E-9E33-6E799D250416"}]}`
helperMock.On("queryAPIWithHeader", ctx, dockerCloudEnv.GetTenantQueryURL(), "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken", dockerCloudEnv)
helperMock.On("queryToken", mock.AnythingOfType("login.CloudEnvironment"), data, "F5773994-FE88-482E-9E33-6E799D250416").Return(azureToken{
RefreshToken: "newRefreshToken",
AccessToken: "newAccessToken",
ExpiresIn: 3600,
Foci: "1",
}, nil)
azureLogin, err := testLoginService(t, helperMock, ces)
assert.NilError(t, err)
err = azureLogin.Login(ctx, "", "AzureDockerCloud")
assert.NilError(t, err)
loginToken, err := azureLogin.tokenStore.readToken()
assert.NilError(t, err)
assert.Equal(t, loginToken.Token.AccessToken, "newAccessToken")
assert.Equal(t, loginToken.Token.RefreshToken, "newRefreshToken")
assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry))
assert.Equal(t, loginToken.TenantID, "F5773994-FE88-482E-9E33-6E799D250416")
assert.Equal(t, loginToken.Token.Type(), "Bearer")
assert.Equal(t, loginToken.CloudEnvironment, "AzureDockerCloud")
assert.Equal(t, metadataReqCount, int32(1))
}
// Don't warn about refreshToken parameter taking the same value for all invocations
// nolint:unparam
func refreshTokenData(refreshToken string, ce CloudEnvironment) url.Values {
return url.Values{ return url.Values{
"grant_type": []string{"refresh_token"}, "grant_type": []string{"refresh_token"},
"client_id": []string{clientID}, "client_id": []string{clientID},
"scope": []string{scopes}, "scope": []string{ce.GetTokenScope()},
"refresh_token": []string{refreshToken}, "refresh_token": []string{refreshToken},
} }
} }
@ -394,13 +564,13 @@ type MockAzureHelper struct {
mock.Mock mock.Mock
} }
func (s *MockAzureHelper) getDeviceCodeFlowToken() (adal.Token, error) { func (s *MockAzureHelper) getDeviceCodeFlowToken(ce CloudEnvironment) (adal.Token, error) {
args := s.Called() args := s.Called(ce)
return args.Get(0).(adal.Token), args.Error(1) return args.Get(0).(adal.Token), args.Error(1)
} }
func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token azureToken, err error) { func (s *MockAzureHelper) queryToken(ce CloudEnvironment, data url.Values, tenantID string) (token azureToken, err error) {
args := s.Called(data, tenantID) args := s.Called(ce, data, tenantID)
return args.Get(0).(azureToken), args.Error(1) return args.Get(0).(azureToken), args.Error(1)
} }
@ -409,7 +579,16 @@ func (s *MockAzureHelper) queryAPIWithHeader(ctx context.Context, authorizationU
return args.Get(0).([]byte), args.Int(1), args.Error(2) return args.Get(0).([]byte), args.Int(1), args.Error(2)
} }
func (s *MockAzureHelper) openAzureLoginPage(redirectURL string) error { func (s *MockAzureHelper) openAzureLoginPage(redirectURL string, ce CloudEnvironment) error {
args := s.Called(redirectURL) args := s.Called(redirectURL, ce)
return args.Error(0) return args.Error(0)
} }
type MockCloudEnvironmentService struct {
mock.Mock
}
func (s *MockCloudEnvironmentService) Get(name string) (CloudEnvironment, error) {
args := s.Called(name)
return args.Get(0).(CloudEnvironment), args.Error(1)
}

View File

@ -34,8 +34,9 @@ type tokenStore struct {
// TokenInfo data stored in tokenStore // TokenInfo data stored in tokenStore
type TokenInfo struct { type TokenInfo struct {
Token oauth2.Token `json:"oauthToken"` Token oauth2.Token `json:"oauthToken"`
TenantID string `json:"tenantId"` TenantID string `json:"tenantId"`
CloudEnvironment string `json:"cloudEnvironment"`
} }
func newTokenStore(path string) (tokenStore, error) { func newTokenStore(path string) (tokenStore, error) {
@ -82,6 +83,9 @@ func (store tokenStore) readToken() (TokenInfo, error) {
if err := json.Unmarshal(bytes, &loginInfo); err != nil { if err := json.Unmarshal(bytes, &loginInfo); err != nil {
return TokenInfo{}, err return TokenInfo{}, err
} }
if loginInfo.CloudEnvironment == "" {
loginInfo.CloudEnvironment = AzurePublicCloudName
}
return loginInfo, nil return loginInfo, nil
} }

View File

@ -40,6 +40,7 @@ func AzureLoginCommand() *cobra.Command {
flags.StringVar(&opts.TenantID, "tenant-id", "", "Specify tenant ID to use") flags.StringVar(&opts.TenantID, "tenant-id", "", "Specify tenant ID to use")
flags.StringVar(&opts.ClientID, "client-id", "", "Client ID for Service principal login") flags.StringVar(&opts.ClientID, "client-id", "", "Client ID for Service principal login")
flags.StringVar(&opts.ClientSecret, "client-secret", "", "Client secret for Service principal login") flags.StringVar(&opts.ClientSecret, "client-secret", "", "Client secret for Service principal login")
flags.StringVar(&opts.CloudName, "cloud-name", "", "Name of a registered Azure cloud")
return cmd return cmd
} }