Pass in context to login tenant query, so it gets cancelled if the user Ctrl+C

Signed-off-by: Guillaume Tardif <guillaume.tardif@docker.com>
This commit is contained in:
Guillaume Tardif 2020-09-03 18:21:30 +02:00
parent 81f2496b5e
commit 76c92a8359
3 changed files with 28 additions and 20 deletions

View File

@ -17,6 +17,7 @@
package login package login
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -40,7 +41,7 @@ var (
type apiHelper interface { type apiHelper interface {
queryToken(data url.Values, tenantID string) (azureToken, error) queryToken(data url.Values, tenantID string) (azureToken, error)
openAzureLoginPage(redirectURL string) error openAzureLoginPage(redirectURL string) error
queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error)
getDeviceCodeFlowToken() (adal.Token, error) getDeviceCodeFlowToken() (adal.Token, error)
} }
@ -62,11 +63,12 @@ func (helper azureAPIHelper) openAzureLoginPage(redirectURL string) error {
return openbrowser(authURL) return openbrowser(authURL)
} }
func (helper azureAPIHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) { func (helper azureAPIHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) {
req, err := http.NewRequest(http.MethodGet, authorizationURL, nil) req, err := http.NewRequest(http.MethodGet, authorizationURL, nil)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
req = req.WithContext(ctx)
req.Header.Add("Authorization", authorizationHeader) req.Header.Add("Authorization", authorizationHeader)
res, err := http.DefaultClient.Do(req) res, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {

View File

@ -132,8 +132,8 @@ func (login *AzureLoginService) Logout(ctx context.Context) error {
return err return err
} }
func (login *AzureLoginService) getTenantAndValidateLogin(accessToken string, refreshToken string, requestedTenantID string) error { func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, accessToken string, refreshToken string, requestedTenantID string) error {
bits, statusCode, err := login.apiHelper.queryAPIWithHeader(getTenantURL, fmt.Sprintf("Bearer %s", accessToken)) bits, statusCode, err := login.apiHelper.queryAPIWithHeader(ctx, getTenantURL, fmt.Sprintf("Bearer %s", accessToken))
if err != nil { if err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err) return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err)
} }
@ -189,7 +189,7 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err) return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err)
} }
token := dcft.token token := dcft.token
return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID) return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
case q := <-queryCh: case q := <-queryCh:
if q.err != nil { if q.err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "unhandled local login server error: %s", err) return errors.Wrapf(errdefs.ErrLoginFailed, "unhandled local login server error: %s", err)
@ -209,7 +209,7 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
if err != nil { if err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err) return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err)
} }
return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID) return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
} }
} }

View File

@ -127,6 +127,7 @@ func TestInvalidLogin(t *testing.T) {
func TestValidLogin(t *testing.T) { func TestValidLogin(t *testing.T) {
var redirectURL string var redirectURL string
ctx := context.TODO()
m := &MockAzureHelper{} m := &MockAzureHelper{}
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
redirectURL = args.Get(0).(string) redirectURL = args.Get(0).(string)
@ -152,7 +153,7 @@ func TestValidLogin(t *testing.T) {
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken") data := refreshTokenData("firstRefreshToken")
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
RefreshToken: "newRefreshToken", RefreshToken: "newRefreshToken",
@ -163,7 +164,7 @@ func TestValidLogin(t *testing.T) {
azureLogin, err := testLoginService(t, m) azureLogin, err := testLoginService(t, m)
assert.NilError(t, err) assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "") err = azureLogin.Login(ctx, "")
assert.NilError(t, err) assert.NilError(t, err)
loginToken, err := azureLogin.tokenStore.readToken() loginToken, err := azureLogin.tokenStore.readToken()
@ -203,7 +204,8 @@ func TestValidLoginRequestedTenant(t *testing.T) {
authBody := `{"value":[{"id":"/tenants/00000000-c56d-43e8-9549-dd230ce8a038","tenantId":"00000000-c56d-43e8-9549-dd230ce8a038"}, authBody := `{"value":[{"id":"/tenants/00000000-c56d-43e8-9549-dd230ce8a038","tenantId":"00000000-c56d-43e8-9549-dd230ce8a038"},
{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` {"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) ctx := context.TODO()
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken") data := refreshTokenData("firstRefreshToken")
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
RefreshToken: "newRefreshToken", RefreshToken: "newRefreshToken",
@ -214,7 +216,7 @@ func TestValidLoginRequestedTenant(t *testing.T) {
azureLogin, err := testLoginService(t, m) azureLogin, err := testLoginService(t, m)
assert.NilError(t, err) assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "12345a7c-c56d-43e8-9549-dd230ce8a038") err = azureLogin.Login(ctx, "12345a7c-c56d-43e8-9549-dd230ce8a038")
assert.NilError(t, err) assert.NilError(t, err)
loginToken, err := azureLogin.tokenStore.readToken() loginToken, err := azureLogin.tokenStore.readToken()
@ -251,13 +253,14 @@ func TestLoginNoTenant(t *testing.T) {
Foci: "1", Foci: "1",
}, nil) }, nil)
ctx := context.TODO()
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
azureLogin, err := testLoginService(t, m) azureLogin, err := testLoginService(t, m)
assert.NilError(t, err) assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "00000000-c56d-43e8-9549-dd230ce8a038") err = azureLogin.Login(ctx, "00000000-c56d-43e8-9549-dd230ce8a038")
assert.Error(t, err, "could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed") assert.Error(t, err, "could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed")
} }
@ -286,13 +289,14 @@ func TestLoginRequestedTenantNotFound(t *testing.T) {
Foci: "1", Foci: "1",
}, nil) }, nil)
ctx := context.TODO()
authBody := `{"value":[]}` authBody := `{"value":[]}`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
azureLogin, err := testLoginService(t, m) azureLogin, err := testLoginService(t, m)
assert.NilError(t, err) assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "") err = azureLogin.Login(ctx, "")
assert.Error(t, err, "could not find azure tenant: login failed") assert.Error(t, err, "could not find azure tenant: login failed")
} }
@ -323,12 +327,13 @@ func TestLoginAuthorizationFailed(t *testing.T) {
authBody := `[access denied]` authBody := `[access denied]`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil) ctx := context.TODO()
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
azureLogin, err := testLoginService(t, m) azureLogin, err := testLoginService(t, m)
assert.NilError(t, err) assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "") err = azureLogin.Login(ctx, "")
assert.Error(t, err, "unable to login status code 400: [access denied]: login failed") assert.Error(t, err, "unable to login status code 400: [access denied]: login failed")
} }
@ -339,7 +344,8 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) {
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}` authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil) ctx := context.TODO()
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken") data := refreshTokenData("firstRefreshToken")
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{ m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
RefreshToken: "newRefreshToken", RefreshToken: "newRefreshToken",
@ -350,7 +356,7 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) {
azureLogin, err := testLoginService(t, m) azureLogin, err := testLoginService(t, m)
assert.NilError(t, err) assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "") err = azureLogin.Login(ctx, "")
assert.NilError(t, err) assert.NilError(t, err)
loginToken, err := azureLogin.tokenStore.readToken() loginToken, err := azureLogin.tokenStore.readToken()
@ -398,8 +404,8 @@ func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token az
return args.Get(0).(azureToken), args.Error(1) return args.Get(0).(azureToken), args.Error(1)
} }
func (s *MockAzureHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) { func (s *MockAzureHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) {
args := s.Called(authorizationURL, authorizationHeader) args := s.Called(ctx, authorizationURL, authorizationHeader)
return args.Get(0).([]byte), args.Int(1), args.Error(2) return args.Get(0).([]byte), args.Int(1), args.Error(2)
} }