From 81f2496b5eb263f58ca01ea17dbf0cac14ef5f0c Mon Sep 17 00:00:00 2001 From: Guillaume Tardif Date: Thu, 3 Sep 2020 15:35:33 +0200 Subject: [PATCH 1/2] Allow Ctrl+C to cancel CLI when using Azure Device Code Flow login Signed-off-by: Guillaume Tardif --- aci/login/login.go | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/aci/login/login.go b/aci/login/login.go index 775ad065b..433b41144 100644 --- a/aci/login/login.go +++ b/aci/login/login.go @@ -176,18 +176,20 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str return errors.Wrap(errdefs.ErrLoginFailed, "empty redirect URL") } + deviceCodeFlowCh := make(chan deviceCodeFlowResponse, 1) if err = login.apiHelper.openAzureLoginPage(redirectURL); err != nil { - fmt.Println("Could not automatically open a browser, falling back to Azure device code flow authentication") - token, err := login.apiHelper.getDeviceCodeFlowToken() - if err != nil { - return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err) - } - return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID) + login.startDeviceCodeFlow(deviceCodeFlowCh) } select { case <-ctx.Done(): return ctx.Err() + case dcft := <-deviceCodeFlowCh: + if dcft.err != nil { + return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err) + } + token := dcft.token + return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID) case q := <-queryCh: if q.err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "unhandled local login server error: %s", err) @@ -211,6 +213,22 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str } } +type deviceCodeFlowResponse struct { + token adal.Token + err error +} + +func (login *AzureLoginService) startDeviceCodeFlow(deviceCodeFlowCh chan deviceCodeFlowResponse) { + fmt.Println("Could not automatically open a browser, falling back to Azure device code flow authentication") + go func() { + token, err := login.apiHelper.getDeviceCodeFlowToken() + if err != nil { + deviceCodeFlowCh <- deviceCodeFlowResponse{err: err} + } + deviceCodeFlowCh <- deviceCodeFlowResponse{token: token} + }() +} + func getTenantID(tenantValues []tenantValue, requestedTenantID string) (string, error) { if requestedTenantID == "" { if len(tenantValues) < 1 { From 76c92a835939de09bf97cac9df2e04d232ac25e7 Mon Sep 17 00:00:00 2001 From: Guillaume Tardif Date: Thu, 3 Sep 2020 18:21:30 +0200 Subject: [PATCH 2/2] Pass in context to login tenant query, so it gets cancelled if the user Ctrl+C Signed-off-by: Guillaume Tardif --- aci/login/helper.go | 6 ++++-- aci/login/login.go | 8 ++++---- aci/login/login_test.go | 34 ++++++++++++++++++++-------------- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/aci/login/helper.go b/aci/login/helper.go index de6d43520..8e7a00418 100644 --- a/aci/login/helper.go +++ b/aci/login/helper.go @@ -17,6 +17,7 @@ package login import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -40,7 +41,7 @@ var ( type apiHelper interface { queryToken(data url.Values, tenantID string) (azureToken, error) openAzureLoginPage(redirectURL string) error - queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) + queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) getDeviceCodeFlowToken() (adal.Token, error) } @@ -62,11 +63,12 @@ func (helper azureAPIHelper) openAzureLoginPage(redirectURL string) error { return openbrowser(authURL) } -func (helper azureAPIHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) { +func (helper azureAPIHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) { req, err := http.NewRequest(http.MethodGet, authorizationURL, nil) if err != nil { return nil, 0, err } + req = req.WithContext(ctx) req.Header.Add("Authorization", authorizationHeader) res, err := http.DefaultClient.Do(req) if err != nil { diff --git a/aci/login/login.go b/aci/login/login.go index 433b41144..da9474265 100644 --- a/aci/login/login.go +++ b/aci/login/login.go @@ -132,8 +132,8 @@ func (login *AzureLoginService) Logout(ctx context.Context) error { return err } -func (login *AzureLoginService) getTenantAndValidateLogin(accessToken string, refreshToken string, requestedTenantID string) error { - bits, statusCode, err := login.apiHelper.queryAPIWithHeader(getTenantURL, fmt.Sprintf("Bearer %s", accessToken)) +func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, accessToken string, refreshToken string, requestedTenantID string) error { + bits, statusCode, err := login.apiHelper.queryAPIWithHeader(ctx, getTenantURL, fmt.Sprintf("Bearer %s", accessToken)) if err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err) } @@ -189,7 +189,7 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err) } token := dcft.token - return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID) + return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID) case q := <-queryCh: if q.err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "unhandled local login server error: %s", err) @@ -209,7 +209,7 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str if err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err) } - return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID) + return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID) } } diff --git a/aci/login/login_test.go b/aci/login/login_test.go index 2762ebb80..e57f6e4a5 100644 --- a/aci/login/login_test.go +++ b/aci/login/login_test.go @@ -127,6 +127,7 @@ func TestInvalidLogin(t *testing.T) { func TestValidLogin(t *testing.T) { var redirectURL string + ctx := context.TODO() m := &MockAzureHelper{} m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { redirectURL = args.Get(0).(string) @@ -152,7 +153,7 @@ func TestValidLogin(t *testing.T) { authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` - m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) data := refreshTokenData("firstRefreshToken") m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ RefreshToken: "newRefreshToken", @@ -163,7 +164,7 @@ func TestValidLogin(t *testing.T) { azureLogin, err := testLoginService(t, m) assert.NilError(t, err) - err = azureLogin.Login(context.TODO(), "") + err = azureLogin.Login(ctx, "") assert.NilError(t, err) loginToken, err := azureLogin.tokenStore.readToken() @@ -203,7 +204,8 @@ func TestValidLoginRequestedTenant(t *testing.T) { authBody := `{"value":[{"id":"/tenants/00000000-c56d-43e8-9549-dd230ce8a038","tenantId":"00000000-c56d-43e8-9549-dd230ce8a038"}, {"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` - m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + ctx := context.TODO() + m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) data := refreshTokenData("firstRefreshToken") m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ RefreshToken: "newRefreshToken", @@ -214,7 +216,7 @@ func TestValidLoginRequestedTenant(t *testing.T) { azureLogin, err := testLoginService(t, m) assert.NilError(t, err) - err = azureLogin.Login(context.TODO(), "12345a7c-c56d-43e8-9549-dd230ce8a038") + err = azureLogin.Login(ctx, "12345a7c-c56d-43e8-9549-dd230ce8a038") assert.NilError(t, err) loginToken, err := azureLogin.tokenStore.readToken() @@ -251,13 +253,14 @@ func TestLoginNoTenant(t *testing.T) { Foci: "1", }, nil) + ctx := context.TODO() authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` - m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) azureLogin, err := testLoginService(t, m) assert.NilError(t, err) - err = azureLogin.Login(context.TODO(), "00000000-c56d-43e8-9549-dd230ce8a038") + err = azureLogin.Login(ctx, "00000000-c56d-43e8-9549-dd230ce8a038") assert.Error(t, err, "could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed") } @@ -286,13 +289,14 @@ func TestLoginRequestedTenantNotFound(t *testing.T) { Foci: "1", }, nil) + ctx := context.TODO() authBody := `{"value":[]}` - m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) azureLogin, err := testLoginService(t, m) assert.NilError(t, err) - err = azureLogin.Login(context.TODO(), "") + err = azureLogin.Login(ctx, "") assert.Error(t, err, "could not find azure tenant: login failed") } @@ -323,12 +327,13 @@ func TestLoginAuthorizationFailed(t *testing.T) { authBody := `[access denied]` - m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil) + ctx := context.TODO() + m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil) azureLogin, err := testLoginService(t, m) assert.NilError(t, err) - err = azureLogin.Login(context.TODO(), "") + err = azureLogin.Login(ctx, "") assert.Error(t, err, "unable to login status code 400: [access denied]: login failed") } @@ -339,7 +344,8 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) { authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` - m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + ctx := context.TODO() + m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) data := refreshTokenData("firstRefreshToken") m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ RefreshToken: "newRefreshToken", @@ -350,7 +356,7 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) { azureLogin, err := testLoginService(t, m) assert.NilError(t, err) - err = azureLogin.Login(context.TODO(), "") + err = azureLogin.Login(ctx, "") assert.NilError(t, err) loginToken, err := azureLogin.tokenStore.readToken() @@ -398,8 +404,8 @@ func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token az return args.Get(0).(azureToken), args.Error(1) } -func (s *MockAzureHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) { - args := s.Called(authorizationURL, authorizationHeader) +func (s *MockAzureHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) { + args := s.Called(ctx, authorizationURL, authorizationHeader) return args.Get(0).([]byte), args.Int(1), args.Error(2) }