From bba9e055afa7f2e29430444963bf2ff1f72abbc5 Mon Sep 17 00:00:00 2001 From: Guillaume Tardif Date: Wed, 1 Jul 2020 12:25:18 +0200 Subject: [PATCH] Allow users to specify tenanted when logging into azure (if several tenants for azure account) --- azure/backend.go | 2 +- azure/login/login.go | 30 +++++++++--- azure/login/login_test.go | 92 +++++++++++++++++++++++++++++++++++-- cli/cmd/login/azurelogin.go | 28 +++++++++++ cli/cmd/login/login.go | 14 ++---- 5 files changed, 146 insertions(+), 20 deletions(-) create mode 100644 cli/cmd/login/azurelogin.go diff --git a/azure/backend.go b/azure/backend.go index 20406cf59..2aad88827 100644 --- a/azure/backend.go +++ b/azure/backend.go @@ -338,7 +338,7 @@ type aciCloudService struct { } func (cs *aciCloudService) Login(ctx context.Context, params map[string]string) error { - return cs.loginService.Login(ctx) + return cs.loginService.Login(ctx, params[login.TenantIDLoginParam]) } func (cs *aciCloudService) CreateContextData(ctx context.Context, params map[string]string) (interface{}, string, error) { diff --git a/azure/login/login.go b/azure/login/login.go index 3b1e5f546..54a4e5aa6 100644 --- a/azure/login/login.go +++ b/azure/login/login.go @@ -49,6 +49,9 @@ const ( // v1 scope like "https://management.azure.com/.default" for ARM access scopes = "offline_access https://management.azure.com/.default" clientID = "04b07795-8ddb-461a-bbee-02f9e1bf7b46" // Azure CLI client id + + // TenantIDLoginParam + TenantIDLoginParam = "tenantId" ) type ( @@ -121,7 +124,7 @@ func (login AzureLoginService) TestLoginFromServicePrincipal(clientID string, cl } // Login performs an Azure login through a web browser -func (login AzureLoginService) Login(ctx context.Context) error { +func (login AzureLoginService) Login(ctx context.Context, requestedTenantID string) error { queryCh := make(chan localResponse, 1) s, err := NewLocalServer(queryCh) if err != nil { @@ -170,15 +173,15 @@ func (login AzureLoginService) Login(ctx context.Context) error { if err := json.Unmarshal(bits, &t); err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "unable to unmarshal tenant: %s", err) } - if len(t.Value) < 1 { - return errors.Wrap(errdefs.ErrLoginFailed, "could not find azure tenant") + tenantID, err := getTenantID(t.Value, requestedTenantID) + if err != nil { + return errors.Wrap(errdefs.ErrLoginFailed, err.Error()) } - tID := t.Value[0].TenantID - tToken, err := login.refreshToken(token.RefreshToken, tID) + tToken, err := login.refreshToken(token.RefreshToken, tenantID) if err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "unable to refresh token: %s", err) } - loginInfo := TokenInfo{TenantID: tID, Token: tToken} + loginInfo := TokenInfo{TenantID: tenantID, Token: tToken} if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err) @@ -190,6 +193,21 @@ func (login AzureLoginService) Login(ctx context.Context) error { return nil } +func getTenantID(tenantValues []tenantValue, requestedTenantID string) (string, error) { + if requestedTenantID == "" { + if len(tenantValues) < 1 { + return "", errors.Errorf("could not find azure tenant") + } + return tenantValues[0].TenantID, nil + } + for _, tValue := range tenantValues { + if tValue.TenantID == requestedTenantID { + return tValue.TenantID, nil + } + } + return "", errors.Errorf("could not find requested azure tenant %s", requestedTenantID) +} + func getTokenStorePath() string { cliPath, _ := cli.AccessTokensPath() return filepath.Join(filepath.Dir(cliPath), tokenStoreFilename) diff --git a/azure/login/login_test.go b/azure/login/login_test.go index 8b80cbd7b..012fa1590 100644 --- a/azure/login/login_test.go +++ b/azure/login/login_test.go @@ -125,7 +125,7 @@ func (suite *LoginSuite) TestInvalidLogin() { azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper) Expect(err).To(BeNil()) - err = azureLogin.Login(context.TODO()) + err = azureLogin.Login(context.TODO(), "") Expect(err.Error()).To(BeEquivalentTo("no login code: login failed")) } @@ -166,7 +166,57 @@ func (suite *LoginSuite) TestValidLogin() { azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper) Expect(err).To(BeNil()) - err = azureLogin.Login(context.TODO()) + err = azureLogin.Login(context.TODO(), "") + Expect(err).To(BeNil()) + + loginToken, err := suite.azureLogin.tokenStore.readToken() + Expect(err).To(BeNil()) + Expect(loginToken.Token.AccessToken).To(Equal("newAccessToken")) + Expect(loginToken.Token.RefreshToken).To(Equal("newRefreshToken")) + Expect(loginToken.Token.Expiry).To(BeTemporally(">", time.Now().Add(3500*time.Second))) + Expect(loginToken.TenantID).To(Equal("12345a7c-c56d-43e8-9549-dd230ce8a038")) + Expect(loginToken.Token.Type()).To(Equal("Bearer")) +} + +func (suite *LoginSuite) TestValidLoginRequestedTenant() { + var redirectURL string + suite.mockHelper.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { + redirectURL = args.Get(0).(string) + err := queryKeyValue(redirectURL, "code", "123456879") + Expect(err).To(BeNil()) + }) + + suite.mockHelper.On("queryToken", mock.MatchedBy(func(data url.Values) bool { + //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage + return reflect.DeepEqual(data, url.Values{ + "grant_type": []string{"authorization_code"}, + "client_id": []string{clientID}, + "code": []string{"123456879"}, + "scope": []string{scopes}, + "redirect_uri": []string{redirectURL}, + }) + }), "organizations").Return(azureToken{ + RefreshToken: "firstRefreshToken", + AccessToken: "firstAccessToken", + ExpiresIn: 3600, + Foci: "1", + }, nil) + + authBody := `{"value":[{"id":"/tenants/00000000-c56d-43e8-9549-dd230ce8a038","tenantId":"00000000-c56d-43e8-9549-dd230ce8a038"}, + {"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` + + suite.mockHelper.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + data := refreshTokenData("firstRefreshToken") + suite.mockHelper.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ + RefreshToken: "newRefreshToken", + AccessToken: "newAccessToken", + ExpiresIn: 3600, + Foci: "1", + }, nil) + azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper) + Expect(err).To(BeNil()) + + err = azureLogin.Login(context.TODO(), "12345a7c-c56d-43e8-9549-dd230ce8a038") Expect(err).To(BeNil()) loginToken, err := suite.azureLogin.tokenStore.readToken() @@ -202,13 +252,47 @@ func (suite *LoginSuite) TestLoginNoTenant() { Foci: "1", }, nil) + authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` + suite.mockHelper.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + + azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper) + Expect(err).To(BeNil()) + + err = azureLogin.Login(context.TODO(), "00000000-c56d-43e8-9549-dd230ce8a038") + Expect(err.Error()).To(BeEquivalentTo("could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed")) +} + +func (suite *LoginSuite) TestLoginRequestedTenantNotFound() { + var redirectURL string + suite.mockHelper.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { + redirectURL = args.Get(0).(string) + err := queryKeyValue(redirectURL, "code", "123456879") + Expect(err).To(BeNil()) + }) + + suite.mockHelper.On("queryToken", mock.MatchedBy(func(data url.Values) bool { + //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage + return reflect.DeepEqual(data, url.Values{ + "grant_type": []string{"authorization_code"}, + "client_id": []string{clientID}, + "code": []string{"123456879"}, + "scope": []string{scopes}, + "redirect_uri": []string{redirectURL}, + }) + }), "organizations").Return(azureToken{ + RefreshToken: "firstRefreshToken", + AccessToken: "firstAccessToken", + ExpiresIn: 3600, + Foci: "1", + }, nil) + authBody := `{"value":[]}` suite.mockHelper.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper) Expect(err).To(BeNil()) - err = azureLogin.Login(context.TODO()) + err = azureLogin.Login(context.TODO(), "") Expect(err.Error()).To(BeEquivalentTo("could not find azure tenant: login failed")) } @@ -243,7 +327,7 @@ func (suite *LoginSuite) TestLoginAuthorizationFailed() { azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper) Expect(err).To(BeNil()) - err = azureLogin.Login(context.TODO()) + err = azureLogin.Login(context.TODO(), "") Expect(err.Error()).To(BeEquivalentTo("unable to login status code 400: [access denied]: login failed")) } diff --git a/cli/cmd/login/azurelogin.go b/cli/cmd/login/azurelogin.go new file mode 100644 index 000000000..8c7e7824f --- /dev/null +++ b/cli/cmd/login/azurelogin.go @@ -0,0 +1,28 @@ +package login + +import ( + "github.com/spf13/cobra" + + "github.com/docker/api/azure/login" +) + +type azureLoginOpts struct { + tenantID string +} + +// AzureLoginCommand returns the azure login command +func AzureLoginCommand() *cobra.Command { + opts := azureLoginOpts{} + cmd := &cobra.Command{ + Use: "azure", + Short: "Log in to azure", + Args: cobra.MaximumNArgs(0), + RunE: func(cmd *cobra.Command, args []string) error { + return cloudLogin(cmd, "aci", map[string]string{login.TenantIDLoginParam: opts.tenantID}) + }, + } + flags := cmd.Flags() + flags.StringVar(&opts.tenantID, "tenant-id", "", "Specify tenant ID to use from your azure account") + + return cmd +} diff --git a/cli/cmd/login/login.go b/cli/cmd/login/login.go index 1cff03ac3..6ab8c8ffe 100644 --- a/cli/cmd/login/login.go +++ b/cli/cmd/login/login.go @@ -34,7 +34,7 @@ import ( // Command returns the login command func Command() *cobra.Command { cmd := &cobra.Command{ - Use: "login [OPTIONS] [SERVER] | login azure", + Use: "login [OPTIONS] [SERVER]", Short: "Log in to a Docker registry", Long: "Log in to a Docker registry or cloud backend.\nIf no registry server is specified, the default is defined by the daemon.", Args: cobra.MaximumNArgs(1), @@ -47,29 +47,25 @@ func Command() *cobra.Command { flags.BoolP("password-stdin", "", false, "Take the password from stdin") mobyflags.AddMobyFlagsForRetrocompatibility(flags) + cmd.AddCommand(AzureLoginCommand()) return cmd } func runLogin(cmd *cobra.Command, args []string) error { if len(args) == 1 && !strings.Contains(args[0], ".") { backend := args[0] - switch backend { - case "azure": - return cloudLogin(cmd, "aci") - default: - return errors.New("unknown backend type for cloud login: " + backend) - } + return errors.New("unknown backend type for cloud login: " + backend) } return mobycli.ExecCmd(cmd) } -func cloudLogin(cmd *cobra.Command, backendType string) error { +func cloudLogin(cmd *cobra.Command, backendType string, params map[string]string) error { ctx := cmd.Context() cs, err := client.GetCloudService(ctx, backendType) if err != nil { return errors.Wrap(errdefs.ErrLoginFailed, "cannot connect to backend") } - err = cs.Login(ctx, nil) + err = cs.Login(ctx, params) if errors.Is(err, context.Canceled) { return errors.New("login canceled") }