Pass in context to login tenant query, so it gets cancelled if the user Ctrl+C

Signed-off-by: Guillaume Tardif <guillaume.tardif@docker.com>
This commit is contained in:
Guillaume Tardif 2020-09-03 18:21:30 +02:00
parent 81f2496b5e
commit 76c92a8359
3 changed files with 28 additions and 20 deletions

View File

@ -17,6 +17,7 @@
package login
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
@ -40,7 +41,7 @@ var (
type apiHelper interface {
queryToken(data url.Values, tenantID string) (azureToken, error)
openAzureLoginPage(redirectURL string) error
queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error)
queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error)
getDeviceCodeFlowToken() (adal.Token, error)
}
@ -62,11 +63,12 @@ func (helper azureAPIHelper) openAzureLoginPage(redirectURL string) error {
return openbrowser(authURL)
}
func (helper azureAPIHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
func (helper azureAPIHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) {
req, err := http.NewRequest(http.MethodGet, authorizationURL, nil)
if err != nil {
return nil, 0, err
}
req = req.WithContext(ctx)
req.Header.Add("Authorization", authorizationHeader)
res, err := http.DefaultClient.Do(req)
if err != nil {

View File

@ -132,8 +132,8 @@ func (login *AzureLoginService) Logout(ctx context.Context) error {
return err
}
func (login *AzureLoginService) getTenantAndValidateLogin(accessToken string, refreshToken string, requestedTenantID string) error {
bits, statusCode, err := login.apiHelper.queryAPIWithHeader(getTenantURL, fmt.Sprintf("Bearer %s", accessToken))
func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, accessToken string, refreshToken string, requestedTenantID string) error {
bits, statusCode, err := login.apiHelper.queryAPIWithHeader(ctx, getTenantURL, fmt.Sprintf("Bearer %s", accessToken))
if err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err)
}
@ -189,7 +189,7 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err)
}
token := dcft.token
return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID)
return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
case q := <-queryCh:
if q.err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "unhandled local login server error: %s", err)
@ -209,7 +209,7 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
if err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err)
}
return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID)
return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
}
}

View File

@ -127,6 +127,7 @@ func TestInvalidLogin(t *testing.T) {
func TestValidLogin(t *testing.T) {
var redirectURL string
ctx := context.TODO()
m := &MockAzureHelper{}
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
redirectURL = args.Get(0).(string)
@ -152,7 +153,7 @@ func TestValidLogin(t *testing.T) {
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken")
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
RefreshToken: "newRefreshToken",
@ -163,7 +164,7 @@ func TestValidLogin(t *testing.T) {
azureLogin, err := testLoginService(t, m)
assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "")
err = azureLogin.Login(ctx, "")
assert.NilError(t, err)
loginToken, err := azureLogin.tokenStore.readToken()
@ -203,7 +204,8 @@ func TestValidLoginRequestedTenant(t *testing.T) {
authBody := `{"value":[{"id":"/tenants/00000000-c56d-43e8-9549-dd230ce8a038","tenantId":"00000000-c56d-43e8-9549-dd230ce8a038"},
{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
ctx := context.TODO()
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken")
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
RefreshToken: "newRefreshToken",
@ -214,7 +216,7 @@ func TestValidLoginRequestedTenant(t *testing.T) {
azureLogin, err := testLoginService(t, m)
assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "12345a7c-c56d-43e8-9549-dd230ce8a038")
err = azureLogin.Login(ctx, "12345a7c-c56d-43e8-9549-dd230ce8a038")
assert.NilError(t, err)
loginToken, err := azureLogin.tokenStore.readToken()
@ -251,13 +253,14 @@ func TestLoginNoTenant(t *testing.T) {
Foci: "1",
}, nil)
ctx := context.TODO()
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
azureLogin, err := testLoginService(t, m)
assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "00000000-c56d-43e8-9549-dd230ce8a038")
err = azureLogin.Login(ctx, "00000000-c56d-43e8-9549-dd230ce8a038")
assert.Error(t, err, "could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed")
}
@ -286,13 +289,14 @@ func TestLoginRequestedTenantNotFound(t *testing.T) {
Foci: "1",
}, nil)
ctx := context.TODO()
authBody := `{"value":[]}`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
azureLogin, err := testLoginService(t, m)
assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "")
err = azureLogin.Login(ctx, "")
assert.Error(t, err, "could not find azure tenant: login failed")
}
@ -323,12 +327,13 @@ func TestLoginAuthorizationFailed(t *testing.T) {
authBody := `[access denied]`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
ctx := context.TODO()
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
azureLogin, err := testLoginService(t, m)
assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "")
err = azureLogin.Login(ctx, "")
assert.Error(t, err, "unable to login status code 400: [access denied]: login failed")
}
@ -339,7 +344,8 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) {
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
ctx := context.TODO()
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken")
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
RefreshToken: "newRefreshToken",
@ -350,7 +356,7 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) {
azureLogin, err := testLoginService(t, m)
assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "")
err = azureLogin.Login(ctx, "")
assert.NilError(t, err)
loginToken, err := azureLogin.tokenStore.readToken()
@ -398,8 +404,8 @@ func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token az
return args.Get(0).(azureToken), args.Error(1)
}
func (s *MockAzureHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
args := s.Called(authorizationURL, authorizationHeader)
func (s *MockAzureHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) {
args := s.Called(ctx, authorizationURL, authorizationHeader)
return args.Get(0).([]byte), args.Int(1), args.Error(2)
}