mirror of
https://github.com/docker/compose.git
synced 2025-07-08 22:34:26 +02:00
Pass in context to login tenant query, so it gets cancelled if the user Ctrl+C
Signed-off-by: Guillaume Tardif <guillaume.tardif@docker.com>
This commit is contained in:
parent
81f2496b5e
commit
76c92a8359
@ -17,6 +17,7 @@
|
|||||||
package login
|
package login
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
@ -40,7 +41,7 @@ var (
|
|||||||
type apiHelper interface {
|
type apiHelper interface {
|
||||||
queryToken(data url.Values, tenantID string) (azureToken, error)
|
queryToken(data url.Values, tenantID string) (azureToken, error)
|
||||||
openAzureLoginPage(redirectURL string) error
|
openAzureLoginPage(redirectURL string) error
|
||||||
queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error)
|
queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error)
|
||||||
getDeviceCodeFlowToken() (adal.Token, error)
|
getDeviceCodeFlowToken() (adal.Token, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,11 +63,12 @@ func (helper azureAPIHelper) openAzureLoginPage(redirectURL string) error {
|
|||||||
return openbrowser(authURL)
|
return openbrowser(authURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (helper azureAPIHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
|
func (helper azureAPIHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) {
|
||||||
req, err := http.NewRequest(http.MethodGet, authorizationURL, nil)
|
req, err := http.NewRequest(http.MethodGet, authorizationURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
req = req.WithContext(ctx)
|
||||||
req.Header.Add("Authorization", authorizationHeader)
|
req.Header.Add("Authorization", authorizationHeader)
|
||||||
res, err := http.DefaultClient.Do(req)
|
res, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -132,8 +132,8 @@ func (login *AzureLoginService) Logout(ctx context.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (login *AzureLoginService) getTenantAndValidateLogin(accessToken string, refreshToken string, requestedTenantID string) error {
|
func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, accessToken string, refreshToken string, requestedTenantID string) error {
|
||||||
bits, statusCode, err := login.apiHelper.queryAPIWithHeader(getTenantURL, fmt.Sprintf("Bearer %s", accessToken))
|
bits, statusCode, err := login.apiHelper.queryAPIWithHeader(ctx, getTenantURL, fmt.Sprintf("Bearer %s", accessToken))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err)
|
return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err)
|
||||||
}
|
}
|
||||||
@ -189,7 +189,7 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
|
|||||||
return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err)
|
return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err)
|
||||||
}
|
}
|
||||||
token := dcft.token
|
token := dcft.token
|
||||||
return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID)
|
return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
|
||||||
case q := <-queryCh:
|
case q := <-queryCh:
|
||||||
if q.err != nil {
|
if q.err != nil {
|
||||||
return errors.Wrapf(errdefs.ErrLoginFailed, "unhandled local login server error: %s", err)
|
return errors.Wrapf(errdefs.ErrLoginFailed, "unhandled local login server error: %s", err)
|
||||||
@ -209,7 +209,7 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err)
|
return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err)
|
||||||
}
|
}
|
||||||
return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID)
|
return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -127,6 +127,7 @@ func TestInvalidLogin(t *testing.T) {
|
|||||||
|
|
||||||
func TestValidLogin(t *testing.T) {
|
func TestValidLogin(t *testing.T) {
|
||||||
var redirectURL string
|
var redirectURL string
|
||||||
|
ctx := context.TODO()
|
||||||
m := &MockAzureHelper{}
|
m := &MockAzureHelper{}
|
||||||
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
|
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
|
||||||
redirectURL = args.Get(0).(string)
|
redirectURL = args.Get(0).(string)
|
||||||
@ -152,7 +153,7 @@ func TestValidLogin(t *testing.T) {
|
|||||||
|
|
||||||
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
|
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
|
||||||
|
|
||||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||||
data := refreshTokenData("firstRefreshToken")
|
data := refreshTokenData("firstRefreshToken")
|
||||||
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
|
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
|
||||||
RefreshToken: "newRefreshToken",
|
RefreshToken: "newRefreshToken",
|
||||||
@ -163,7 +164,7 @@ func TestValidLogin(t *testing.T) {
|
|||||||
azureLogin, err := testLoginService(t, m)
|
azureLogin, err := testLoginService(t, m)
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
|
|
||||||
err = azureLogin.Login(context.TODO(), "")
|
err = azureLogin.Login(ctx, "")
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
|
|
||||||
loginToken, err := azureLogin.tokenStore.readToken()
|
loginToken, err := azureLogin.tokenStore.readToken()
|
||||||
@ -203,7 +204,8 @@ func TestValidLoginRequestedTenant(t *testing.T) {
|
|||||||
authBody := `{"value":[{"id":"/tenants/00000000-c56d-43e8-9549-dd230ce8a038","tenantId":"00000000-c56d-43e8-9549-dd230ce8a038"},
|
authBody := `{"value":[{"id":"/tenants/00000000-c56d-43e8-9549-dd230ce8a038","tenantId":"00000000-c56d-43e8-9549-dd230ce8a038"},
|
||||||
{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
|
{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
|
||||||
|
|
||||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
ctx := context.TODO()
|
||||||
|
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||||
data := refreshTokenData("firstRefreshToken")
|
data := refreshTokenData("firstRefreshToken")
|
||||||
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
|
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
|
||||||
RefreshToken: "newRefreshToken",
|
RefreshToken: "newRefreshToken",
|
||||||
@ -214,7 +216,7 @@ func TestValidLoginRequestedTenant(t *testing.T) {
|
|||||||
azureLogin, err := testLoginService(t, m)
|
azureLogin, err := testLoginService(t, m)
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
|
|
||||||
err = azureLogin.Login(context.TODO(), "12345a7c-c56d-43e8-9549-dd230ce8a038")
|
err = azureLogin.Login(ctx, "12345a7c-c56d-43e8-9549-dd230ce8a038")
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
|
|
||||||
loginToken, err := azureLogin.tokenStore.readToken()
|
loginToken, err := azureLogin.tokenStore.readToken()
|
||||||
@ -251,13 +253,14 @@ func TestLoginNoTenant(t *testing.T) {
|
|||||||
Foci: "1",
|
Foci: "1",
|
||||||
}, nil)
|
}, nil)
|
||||||
|
|
||||||
|
ctx := context.TODO()
|
||||||
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
|
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
|
||||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||||
|
|
||||||
azureLogin, err := testLoginService(t, m)
|
azureLogin, err := testLoginService(t, m)
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
|
|
||||||
err = azureLogin.Login(context.TODO(), "00000000-c56d-43e8-9549-dd230ce8a038")
|
err = azureLogin.Login(ctx, "00000000-c56d-43e8-9549-dd230ce8a038")
|
||||||
assert.Error(t, err, "could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed")
|
assert.Error(t, err, "could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -286,13 +289,14 @@ func TestLoginRequestedTenantNotFound(t *testing.T) {
|
|||||||
Foci: "1",
|
Foci: "1",
|
||||||
}, nil)
|
}, nil)
|
||||||
|
|
||||||
|
ctx := context.TODO()
|
||||||
authBody := `{"value":[]}`
|
authBody := `{"value":[]}`
|
||||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||||
|
|
||||||
azureLogin, err := testLoginService(t, m)
|
azureLogin, err := testLoginService(t, m)
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
|
|
||||||
err = azureLogin.Login(context.TODO(), "")
|
err = azureLogin.Login(ctx, "")
|
||||||
assert.Error(t, err, "could not find azure tenant: login failed")
|
assert.Error(t, err, "could not find azure tenant: login failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -323,12 +327,13 @@ func TestLoginAuthorizationFailed(t *testing.T) {
|
|||||||
|
|
||||||
authBody := `[access denied]`
|
authBody := `[access denied]`
|
||||||
|
|
||||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
|
ctx := context.TODO()
|
||||||
|
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
|
||||||
|
|
||||||
azureLogin, err := testLoginService(t, m)
|
azureLogin, err := testLoginService(t, m)
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
|
|
||||||
err = azureLogin.Login(context.TODO(), "")
|
err = azureLogin.Login(ctx, "")
|
||||||
assert.Error(t, err, "unable to login status code 400: [access denied]: login failed")
|
assert.Error(t, err, "unable to login status code 400: [access denied]: login failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -339,7 +344,8 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) {
|
|||||||
|
|
||||||
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
|
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
|
||||||
|
|
||||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
ctx := context.TODO()
|
||||||
|
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||||
data := refreshTokenData("firstRefreshToken")
|
data := refreshTokenData("firstRefreshToken")
|
||||||
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
|
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
|
||||||
RefreshToken: "newRefreshToken",
|
RefreshToken: "newRefreshToken",
|
||||||
@ -350,7 +356,7 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) {
|
|||||||
azureLogin, err := testLoginService(t, m)
|
azureLogin, err := testLoginService(t, m)
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
|
|
||||||
err = azureLogin.Login(context.TODO(), "")
|
err = azureLogin.Login(ctx, "")
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
|
|
||||||
loginToken, err := azureLogin.tokenStore.readToken()
|
loginToken, err := azureLogin.tokenStore.readToken()
|
||||||
@ -398,8 +404,8 @@ func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token az
|
|||||||
return args.Get(0).(azureToken), args.Error(1)
|
return args.Get(0).(azureToken), args.Error(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MockAzureHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
|
func (s *MockAzureHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) {
|
||||||
args := s.Called(authorizationURL, authorizationHeader)
|
args := s.Called(ctx, authorizationURL, authorizationHeader)
|
||||||
return args.Get(0).([]byte), args.Int(1), args.Error(2)
|
return args.Get(0).([]byte), args.Int(1), args.Error(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user