From 76c92a835939de09bf97cac9df2e04d232ac25e7 Mon Sep 17 00:00:00 2001 From: Guillaume Tardif Date: Thu, 3 Sep 2020 18:21:30 +0200 Subject: [PATCH] Pass in context to login tenant query, so it gets cancelled if the user Ctrl+C Signed-off-by: Guillaume Tardif --- aci/login/helper.go | 6 ++++-- aci/login/login.go | 8 ++++---- aci/login/login_test.go | 34 ++++++++++++++++++++-------------- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/aci/login/helper.go b/aci/login/helper.go index de6d43520..8e7a00418 100644 --- a/aci/login/helper.go +++ b/aci/login/helper.go @@ -17,6 +17,7 @@ package login import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -40,7 +41,7 @@ var ( type apiHelper interface { queryToken(data url.Values, tenantID string) (azureToken, error) openAzureLoginPage(redirectURL string) error - queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) + queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) getDeviceCodeFlowToken() (adal.Token, error) } @@ -62,11 +63,12 @@ func (helper azureAPIHelper) openAzureLoginPage(redirectURL string) error { return openbrowser(authURL) } -func (helper azureAPIHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) { +func (helper azureAPIHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) { req, err := http.NewRequest(http.MethodGet, authorizationURL, nil) if err != nil { return nil, 0, err } + req = req.WithContext(ctx) req.Header.Add("Authorization", authorizationHeader) res, err := http.DefaultClient.Do(req) if err != nil { diff --git a/aci/login/login.go b/aci/login/login.go index 433b41144..da9474265 100644 --- a/aci/login/login.go +++ b/aci/login/login.go @@ -132,8 +132,8 @@ func (login *AzureLoginService) Logout(ctx context.Context) error { return err } -func (login *AzureLoginService) getTenantAndValidateLogin(accessToken string, refreshToken string, requestedTenantID string) error { - bits, statusCode, err := login.apiHelper.queryAPIWithHeader(getTenantURL, fmt.Sprintf("Bearer %s", accessToken)) +func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, accessToken string, refreshToken string, requestedTenantID string) error { + bits, statusCode, err := login.apiHelper.queryAPIWithHeader(ctx, getTenantURL, fmt.Sprintf("Bearer %s", accessToken)) if err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err) } @@ -189,7 +189,7 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err) } token := dcft.token - return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID) + return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID) case q := <-queryCh: if q.err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "unhandled local login server error: %s", err) @@ -209,7 +209,7 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str if err != nil { return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err) } - return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID) + return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID) } } diff --git a/aci/login/login_test.go b/aci/login/login_test.go index 2762ebb80..e57f6e4a5 100644 --- a/aci/login/login_test.go +++ b/aci/login/login_test.go @@ -127,6 +127,7 @@ func TestInvalidLogin(t *testing.T) { func TestValidLogin(t *testing.T) { var redirectURL string + ctx := context.TODO() m := &MockAzureHelper{} m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { redirectURL = args.Get(0).(string) @@ -152,7 +153,7 @@ func TestValidLogin(t *testing.T) { authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` - m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) data := refreshTokenData("firstRefreshToken") m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ RefreshToken: "newRefreshToken", @@ -163,7 +164,7 @@ func TestValidLogin(t *testing.T) { azureLogin, err := testLoginService(t, m) assert.NilError(t, err) - err = azureLogin.Login(context.TODO(), "") + err = azureLogin.Login(ctx, "") assert.NilError(t, err) loginToken, err := azureLogin.tokenStore.readToken() @@ -203,7 +204,8 @@ func TestValidLoginRequestedTenant(t *testing.T) { authBody := `{"value":[{"id":"/tenants/00000000-c56d-43e8-9549-dd230ce8a038","tenantId":"00000000-c56d-43e8-9549-dd230ce8a038"}, {"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` - m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + ctx := context.TODO() + m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) data := refreshTokenData("firstRefreshToken") m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ RefreshToken: "newRefreshToken", @@ -214,7 +216,7 @@ func TestValidLoginRequestedTenant(t *testing.T) { azureLogin, err := testLoginService(t, m) assert.NilError(t, err) - err = azureLogin.Login(context.TODO(), "12345a7c-c56d-43e8-9549-dd230ce8a038") + err = azureLogin.Login(ctx, "12345a7c-c56d-43e8-9549-dd230ce8a038") assert.NilError(t, err) loginToken, err := azureLogin.tokenStore.readToken() @@ -251,13 +253,14 @@ func TestLoginNoTenant(t *testing.T) { Foci: "1", }, nil) + ctx := context.TODO() authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` - m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) azureLogin, err := testLoginService(t, m) assert.NilError(t, err) - err = azureLogin.Login(context.TODO(), "00000000-c56d-43e8-9549-dd230ce8a038") + err = azureLogin.Login(ctx, "00000000-c56d-43e8-9549-dd230ce8a038") assert.Error(t, err, "could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed") } @@ -286,13 +289,14 @@ func TestLoginRequestedTenantNotFound(t *testing.T) { Foci: "1", }, nil) + ctx := context.TODO() authBody := `{"value":[]}` - m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) azureLogin, err := testLoginService(t, m) assert.NilError(t, err) - err = azureLogin.Login(context.TODO(), "") + err = azureLogin.Login(ctx, "") assert.Error(t, err, "could not find azure tenant: login failed") } @@ -323,12 +327,13 @@ func TestLoginAuthorizationFailed(t *testing.T) { authBody := `[access denied]` - m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil) + ctx := context.TODO() + m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil) azureLogin, err := testLoginService(t, m) assert.NilError(t, err) - err = azureLogin.Login(context.TODO(), "") + err = azureLogin.Login(ctx, "") assert.Error(t, err, "unable to login status code 400: [access denied]: login failed") } @@ -339,7 +344,8 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) { authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` - m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) + ctx := context.TODO() + m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) data := refreshTokenData("firstRefreshToken") m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ RefreshToken: "newRefreshToken", @@ -350,7 +356,7 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) { azureLogin, err := testLoginService(t, m) assert.NilError(t, err) - err = azureLogin.Login(context.TODO(), "") + err = azureLogin.Login(ctx, "") assert.NilError(t, err) loginToken, err := azureLogin.tokenStore.readToken() @@ -398,8 +404,8 @@ func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token az return args.Get(0).(azureToken), args.Error(1) } -func (s *MockAzureHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) { - args := s.Called(authorizationURL, authorizationHeader) +func (s *MockAzureHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) { + args := s.Called(ctx, authorizationURL, authorizationHeader) return args.Get(0).([]byte), args.Int(1), args.Error(2) }