diff --git a/aci/login/helper.go b/aci/login/helper.go index df6402b05..21739cc54 100644 --- a/aci/login/helper.go +++ b/aci/login/helper.go @@ -27,6 +27,9 @@ import ( "runtime" "strings" + "github.com/Azure/go-autorest/autorest/adal" + "github.com/Azure/go-autorest/autorest/azure/auth" + "github.com/pkg/errors" ) @@ -37,18 +40,29 @@ var ( type apiHelper interface { queryToken(data url.Values, tenantID string) (azureToken, error) openAzureLoginPage(redirectURL string) error - queryAuthorizationAPI(authorizationURL string, authorizationHeader string) ([]byte, int, error) + queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) + getDeviceCodeFlowToken() (adal.Token, error) } type azureAPIHelper struct{} +func (helper azureAPIHelper) getDeviceCodeFlowToken() (adal.Token, error) { + deviceconfig := auth.NewDeviceFlowConfig(clientID, "common") + deviceconfig.Resource = "https://management.core.windows.net/" + spToken, err := deviceconfig.ServicePrincipalToken() + if err != nil { + return adal.Token{}, err + } + return spToken.Token(), err +} + func (helper azureAPIHelper) openAzureLoginPage(redirectURL string) error { state := randomString("", 10) authURL := fmt.Sprintf(authorizeFormat, clientID, redirectURL, state, scopes) return openbrowser(authURL) } -func (helper azureAPIHelper) queryAuthorizationAPI(authorizationURL string, authorizationHeader string) ([]byte, int, error) { +func (helper azureAPIHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) { req, err := http.NewRequest(http.MethodGet, authorizationURL, nil) if err != nil { return nil, 0, err diff --git a/aci/login/login.go b/aci/login/login.go index 51d3a79de..775ad065b 100644 --- a/aci/login/login.go +++ b/aci/login/login.go @@ -28,7 +28,7 @@ import ( "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/adal" - auth2 "github.com/Azure/go-autorest/autorest/azure/auth" + "github.com/Azure/go-autorest/autorest/azure/auth" "github.com/Azure/go-autorest/autorest/date" "github.com/pkg/errors" "golang.org/x/oauth2" @@ -38,9 +38,9 @@ import ( //go login process, derived from code sample provided by MS at https://github.com/devigned/go-az-cli-stuff const ( - authorizeFormat = "https://login.microsoftonline.com/organizations/oauth2/v2.0/authorize?response_type=code&client_id=%s&redirect_uri=%s&state=%s&prompt=select_account&response_mode=query&scope=%s" - tokenEndpoint = "https://login.microsoftonline.com/%s/oauth2/v2.0/token" - authorizationURL = "https://management.azure.com/tenants?api-version=2019-11-01" + authorizeFormat = "https://login.microsoftonline.com/organizations/oauth2/v2.0/authorize?response_type=code&client_id=%s&redirect_uri=%s&state=%s&prompt=select_account&response_mode=query&scope=%s" + tokenEndpoint = "https://login.microsoftonline.com/%s/oauth2/v2.0/token" + getTenantURL = "https://management.azure.com/tenants?api-version=2019-11-01" // scopes for a multi-tenant app works for openid, email, other common scopes, but fails when trying to add a token // v1 scope like "https://management.azure.com/.default" for ARM access scopes = "offline_access https://management.azure.com/.default" @@ -101,7 +101,7 @@ func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper) (*Azu // The resulting token does not include a refresh token func (login *AzureLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error { // Tried with auth2.NewUsernamePasswordConfig() but could not make this work with username / password, setting this for CI with clientID / clientSecret - creds := auth2.NewClientCredentialsConfig(clientID, clientSecret, tenantID) + creds := auth.NewClientCredentialsConfig(clientID, clientSecret, tenantID) spToken, err := creds.ServicePrincipalToken() if err != nil { @@ -132,6 +132,35 @@ func (login *AzureLoginService) Logout(ctx context.Context) error { return err } +func (login *AzureLoginService) getTenantAndValidateLogin(accessToken string, refreshToken string, requestedTenantID string) error { + bits, statusCode, err := login.apiHelper.queryAPIWithHeader(getTenantURL, fmt.Sprintf("Bearer %s", accessToken)) + if err != nil { + return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err) + } + + if statusCode != http.StatusOK { + return errors.Wrapf(errdefs.ErrLoginFailed, "unable to login status code %d: %s", statusCode, bits) + } + var t tenantResult + if err := json.Unmarshal(bits, &t); err != nil { + return errors.Wrapf(errdefs.ErrLoginFailed, "unable to unmarshal tenant: %s", err) + } + tenantID, err := getTenantID(t.Value, requestedTenantID) + if err != nil { + return errors.Wrap(errdefs.ErrLoginFailed, err.Error()) + } + tToken, err := login.refreshToken(refreshToken, tenantID) + if err != nil { + return errors.Wrapf(errdefs.ErrLoginFailed, "unable to refresh token: %s", err) + } + loginInfo := TokenInfo{TenantID: tenantID, Token: tToken} + + if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil { + return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err) + } + return nil +} + // Login performs an Azure login through a web browser func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID string) error { queryCh := make(chan localResponse, 1) @@ -148,7 +177,12 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str } if err = login.apiHelper.openAzureLoginPage(redirectURL); err != nil { - return err + fmt.Println("Could not automatically open a browser, falling back to Azure device code flow authentication") + token, err := login.apiHelper.getDeviceCodeFlowToken() + if err != nil { + return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err) + } + return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID) } select { @@ -173,36 +207,8 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str if err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err) } - - bits, statusCode, err := login.apiHelper.queryAuthorizationAPI(authorizationURL, fmt.Sprintf("Bearer %s", token.AccessToken)) - if err != nil { - return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err) - } - - switch statusCode { - case http.StatusOK: - var t tenantResult - if err := json.Unmarshal(bits, &t); err != nil { - return errors.Wrapf(errdefs.ErrLoginFailed, "unable to unmarshal tenant: %s", err) - } - tenantID, err := getTenantID(t.Value, requestedTenantID) - if err != nil { - return errors.Wrap(errdefs.ErrLoginFailed, err.Error()) - } - tToken, err := login.refreshToken(token.RefreshToken, tenantID) - if err != nil { - return errors.Wrapf(errdefs.ErrLoginFailed, "unable to refresh token: %s", err) - } - loginInfo := TokenInfo{TenantID: tenantID, Token: tToken} - - if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil { - return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err) - } - default: - return errors.Wrapf(errdefs.ErrLoginFailed, "unable to login status code %d: %s", statusCode, bits) - } + return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID) } - return nil } func getTenantID(tenantValues []tenantValue, requestedTenantID string) (string, error) { diff --git a/aci/login/login_test.go b/aci/login/login_test.go index ca39ceeb2..2762ebb80 100644 --- a/aci/login/login_test.go +++ b/aci/login/login_test.go @@ -18,6 +18,7 @@ package login import ( "context" + "errors" "io/ioutil" "net/http" "net/url" @@ -27,6 +28,8 @@ import ( "testing" "time" + "github.com/Azure/go-autorest/autorest/adal" + "github.com/stretchr/testify/mock" "gotest.tools/v3/assert" @@ -113,7 +116,7 @@ func TestInvalidLogin(t *testing.T) { redirectURL := args.Get(0).(string) err := queryKeyValue(redirectURL, "error", "access denied: login failed") assert.NilError(t, err) - }) + }).Return(nil) azureLogin, err := testLoginService(t, m) assert.NilError(t, err) @@ -129,7 +132,7 @@ func TestValidLogin(t *testing.T) { redirectURL = args.Get(0).(string) err := queryKeyValue(redirectURL, "code", "123456879") assert.NilError(t, err) - }) + }).Return(nil) m.On("queryToken", mock.MatchedBy(func(data url.Values) bool { //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage @@ -149,7 +152,7 @@ func TestValidLogin(t *testing.T) { authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` - m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) data := refreshTokenData("firstRefreshToken") m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ RefreshToken: "newRefreshToken", @@ -179,7 +182,7 @@ func TestValidLoginRequestedTenant(t *testing.T) { redirectURL = args.Get(0).(string) err := queryKeyValue(redirectURL, "code", "123456879") assert.NilError(t, err) - }) + }).Return(nil) m.On("queryToken", mock.MatchedBy(func(data url.Values) bool { //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage @@ -200,7 +203,7 @@ func TestValidLoginRequestedTenant(t *testing.T) { authBody := `{"value":[{"id":"/tenants/00000000-c56d-43e8-9549-dd230ce8a038","tenantId":"00000000-c56d-43e8-9549-dd230ce8a038"}, {"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` - m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) data := refreshTokenData("firstRefreshToken") m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ RefreshToken: "newRefreshToken", @@ -230,7 +233,7 @@ func TestLoginNoTenant(t *testing.T) { redirectURL = args.Get(0).(string) err := queryKeyValue(redirectURL, "code", "123456879") assert.NilError(t, err) - }) + }).Return(nil) m.On("queryToken", mock.MatchedBy(func(data url.Values) bool { //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage @@ -249,7 +252,7 @@ func TestLoginNoTenant(t *testing.T) { }, nil) authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` - m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) azureLogin, err := testLoginService(t, m) assert.NilError(t, err) @@ -265,7 +268,7 @@ func TestLoginRequestedTenantNotFound(t *testing.T) { redirectURL = args.Get(0).(string) err := queryKeyValue(redirectURL, "code", "123456879") assert.NilError(t, err) - }) + }).Return(nil) m.On("queryToken", mock.MatchedBy(func(data url.Values) bool { //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage @@ -284,7 +287,7 @@ func TestLoginRequestedTenantNotFound(t *testing.T) { }, nil) authBody := `{"value":[]}` - m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) azureLogin, err := testLoginService(t, m) assert.NilError(t, err) @@ -300,7 +303,7 @@ func TestLoginAuthorizationFailed(t *testing.T) { redirectURL = args.Get(0).(string) err := queryKeyValue(redirectURL, "code", "123456879") assert.NilError(t, err) - }) + }).Return(nil) m.On("queryToken", mock.MatchedBy(func(data url.Values) bool { //Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage @@ -320,7 +323,7 @@ func TestLoginAuthorizationFailed(t *testing.T) { authBody := `[access denied]` - m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil) + m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil) azureLogin, err := testLoginService(t, m) assert.NilError(t, err) @@ -329,6 +332,36 @@ func TestLoginAuthorizationFailed(t *testing.T) { assert.Error(t, err, "unable to login status code 400: [access denied]: login failed") } +func TestValidThroughDeviceCodeFlow(t *testing.T) { + m := &MockAzureHelper{} + m.On("openAzureLoginPage", mock.AnythingOfType("string")).Return(errors.New("Could not open browser")) + m.On("getDeviceCodeFlowToken").Return(adal.Token{AccessToken: "firstAccessToken", RefreshToken: "firstRefreshToken"}, nil) + + authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` + + m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + data := refreshTokenData("firstRefreshToken") + m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ + RefreshToken: "newRefreshToken", + AccessToken: "newAccessToken", + ExpiresIn: 3600, + Foci: "1", + }, nil) + azureLogin, err := testLoginService(t, m) + assert.NilError(t, err) + + err = azureLogin.Login(context.TODO(), "") + assert.NilError(t, err) + + loginToken, err := azureLogin.tokenStore.readToken() + assert.NilError(t, err) + assert.Equal(t, loginToken.Token.AccessToken, "newAccessToken") + assert.Equal(t, loginToken.Token.RefreshToken, "newRefreshToken") + assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry)) + assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038") + assert.Equal(t, loginToken.Token.Type(), "Bearer") +} + func refreshTokenData(refreshToken string) url.Values { return url.Values{ "grant_type": []string{"refresh_token"}, @@ -355,17 +388,22 @@ type MockAzureHelper struct { mock.Mock } +func (s *MockAzureHelper) getDeviceCodeFlowToken() (adal.Token, error) { + args := s.Called() + return args.Get(0).(adal.Token), args.Error(1) +} + func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token azureToken, err error) { args := s.Called(data, tenantID) return args.Get(0).(azureToken), args.Error(1) } -func (s *MockAzureHelper) queryAuthorizationAPI(authorizationURL string, authorizationHeader string) ([]byte, int, error) { +func (s *MockAzureHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) { args := s.Called(authorizationURL, authorizationHeader) return args.Get(0).([]byte), args.Int(1), args.Error(2) } func (s *MockAzureHelper) openAzureLoginPage(redirectURL string) error { - s.Called(redirectURL) - return nil + args := s.Called(redirectURL) + return args.Error(0) }