mirror of https://github.com/docker/compose.git
Merge pull request #577 from docker/aci_device_login_ctrlc
ACI: Allow Ctrl+C to cancel CLI when using Azure Device Code Flow login
This commit is contained in:
commit
cbb416976a
|
@ -17,6 +17,7 @@
|
|||
package login
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
|
@ -40,7 +41,7 @@ var (
|
|||
type apiHelper interface {
|
||||
queryToken(data url.Values, tenantID string) (azureToken, error)
|
||||
openAzureLoginPage(redirectURL string) error
|
||||
queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error)
|
||||
queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error)
|
||||
getDeviceCodeFlowToken() (adal.Token, error)
|
||||
}
|
||||
|
||||
|
@ -62,11 +63,12 @@ func (helper azureAPIHelper) openAzureLoginPage(redirectURL string) error {
|
|||
return openbrowser(authURL)
|
||||
}
|
||||
|
||||
func (helper azureAPIHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
|
||||
func (helper azureAPIHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, authorizationURL, nil)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Add("Authorization", authorizationHeader)
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
|
|
|
@ -132,8 +132,8 @@ func (login *AzureLoginService) Logout(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (login *AzureLoginService) getTenantAndValidateLogin(accessToken string, refreshToken string, requestedTenantID string) error {
|
||||
bits, statusCode, err := login.apiHelper.queryAPIWithHeader(getTenantURL, fmt.Sprintf("Bearer %s", accessToken))
|
||||
func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, accessToken string, refreshToken string, requestedTenantID string) error {
|
||||
bits, statusCode, err := login.apiHelper.queryAPIWithHeader(ctx, getTenantURL, fmt.Sprintf("Bearer %s", accessToken))
|
||||
if err != nil {
|
||||
return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err)
|
||||
}
|
||||
|
@ -176,18 +176,20 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
|
|||
return errors.Wrap(errdefs.ErrLoginFailed, "empty redirect URL")
|
||||
}
|
||||
|
||||
deviceCodeFlowCh := make(chan deviceCodeFlowResponse, 1)
|
||||
if err = login.apiHelper.openAzureLoginPage(redirectURL); err != nil {
|
||||
fmt.Println("Could not automatically open a browser, falling back to Azure device code flow authentication")
|
||||
token, err := login.apiHelper.getDeviceCodeFlowToken()
|
||||
if err != nil {
|
||||
return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err)
|
||||
}
|
||||
return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID)
|
||||
login.startDeviceCodeFlow(deviceCodeFlowCh)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case dcft := <-deviceCodeFlowCh:
|
||||
if dcft.err != nil {
|
||||
return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err)
|
||||
}
|
||||
token := dcft.token
|
||||
return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
|
||||
case q := <-queryCh:
|
||||
if q.err != nil {
|
||||
return errors.Wrapf(errdefs.ErrLoginFailed, "unhandled local login server error: %s", err)
|
||||
|
@ -207,10 +209,26 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
|
|||
if err != nil {
|
||||
return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err)
|
||||
}
|
||||
return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID)
|
||||
return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
|
||||
}
|
||||
}
|
||||
|
||||
type deviceCodeFlowResponse struct {
|
||||
token adal.Token
|
||||
err error
|
||||
}
|
||||
|
||||
func (login *AzureLoginService) startDeviceCodeFlow(deviceCodeFlowCh chan deviceCodeFlowResponse) {
|
||||
fmt.Println("Could not automatically open a browser, falling back to Azure device code flow authentication")
|
||||
go func() {
|
||||
token, err := login.apiHelper.getDeviceCodeFlowToken()
|
||||
if err != nil {
|
||||
deviceCodeFlowCh <- deviceCodeFlowResponse{err: err}
|
||||
}
|
||||
deviceCodeFlowCh <- deviceCodeFlowResponse{token: token}
|
||||
}()
|
||||
}
|
||||
|
||||
func getTenantID(tenantValues []tenantValue, requestedTenantID string) (string, error) {
|
||||
if requestedTenantID == "" {
|
||||
if len(tenantValues) < 1 {
|
||||
|
|
|
@ -127,6 +127,7 @@ func TestInvalidLogin(t *testing.T) {
|
|||
|
||||
func TestValidLogin(t *testing.T) {
|
||||
var redirectURL string
|
||||
ctx := context.TODO()
|
||||
m := &MockAzureHelper{}
|
||||
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
|
||||
redirectURL = args.Get(0).(string)
|
||||
|
@ -152,7 +153,7 @@ func TestValidLogin(t *testing.T) {
|
|||
|
||||
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
|
||||
|
||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
data := refreshTokenData("firstRefreshToken")
|
||||
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
|
||||
RefreshToken: "newRefreshToken",
|
||||
|
@ -163,7 +164,7 @@ func TestValidLogin(t *testing.T) {
|
|||
azureLogin, err := testLoginService(t, m)
|
||||
assert.NilError(t, err)
|
||||
|
||||
err = azureLogin.Login(context.TODO(), "")
|
||||
err = azureLogin.Login(ctx, "")
|
||||
assert.NilError(t, err)
|
||||
|
||||
loginToken, err := azureLogin.tokenStore.readToken()
|
||||
|
@ -203,7 +204,8 @@ func TestValidLoginRequestedTenant(t *testing.T) {
|
|||
authBody := `{"value":[{"id":"/tenants/00000000-c56d-43e8-9549-dd230ce8a038","tenantId":"00000000-c56d-43e8-9549-dd230ce8a038"},
|
||||
{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
|
||||
|
||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
ctx := context.TODO()
|
||||
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
data := refreshTokenData("firstRefreshToken")
|
||||
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
|
||||
RefreshToken: "newRefreshToken",
|
||||
|
@ -214,7 +216,7 @@ func TestValidLoginRequestedTenant(t *testing.T) {
|
|||
azureLogin, err := testLoginService(t, m)
|
||||
assert.NilError(t, err)
|
||||
|
||||
err = azureLogin.Login(context.TODO(), "12345a7c-c56d-43e8-9549-dd230ce8a038")
|
||||
err = azureLogin.Login(ctx, "12345a7c-c56d-43e8-9549-dd230ce8a038")
|
||||
assert.NilError(t, err)
|
||||
|
||||
loginToken, err := azureLogin.tokenStore.readToken()
|
||||
|
@ -251,13 +253,14 @@ func TestLoginNoTenant(t *testing.T) {
|
|||
Foci: "1",
|
||||
}, nil)
|
||||
|
||||
ctx := context.TODO()
|
||||
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
|
||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
|
||||
azureLogin, err := testLoginService(t, m)
|
||||
assert.NilError(t, err)
|
||||
|
||||
err = azureLogin.Login(context.TODO(), "00000000-c56d-43e8-9549-dd230ce8a038")
|
||||
err = azureLogin.Login(ctx, "00000000-c56d-43e8-9549-dd230ce8a038")
|
||||
assert.Error(t, err, "could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed")
|
||||
}
|
||||
|
||||
|
@ -286,13 +289,14 @@ func TestLoginRequestedTenantNotFound(t *testing.T) {
|
|||
Foci: "1",
|
||||
}, nil)
|
||||
|
||||
ctx := context.TODO()
|
||||
authBody := `{"value":[]}`
|
||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
|
||||
azureLogin, err := testLoginService(t, m)
|
||||
assert.NilError(t, err)
|
||||
|
||||
err = azureLogin.Login(context.TODO(), "")
|
||||
err = azureLogin.Login(ctx, "")
|
||||
assert.Error(t, err, "could not find azure tenant: login failed")
|
||||
}
|
||||
|
||||
|
@ -323,12 +327,13 @@ func TestLoginAuthorizationFailed(t *testing.T) {
|
|||
|
||||
authBody := `[access denied]`
|
||||
|
||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
|
||||
ctx := context.TODO()
|
||||
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
|
||||
|
||||
azureLogin, err := testLoginService(t, m)
|
||||
assert.NilError(t, err)
|
||||
|
||||
err = azureLogin.Login(context.TODO(), "")
|
||||
err = azureLogin.Login(ctx, "")
|
||||
assert.Error(t, err, "unable to login status code 400: [access denied]: login failed")
|
||||
}
|
||||
|
||||
|
@ -339,7 +344,8 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) {
|
|||
|
||||
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
|
||||
|
||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
ctx := context.TODO()
|
||||
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
data := refreshTokenData("firstRefreshToken")
|
||||
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
|
||||
RefreshToken: "newRefreshToken",
|
||||
|
@ -350,7 +356,7 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) {
|
|||
azureLogin, err := testLoginService(t, m)
|
||||
assert.NilError(t, err)
|
||||
|
||||
err = azureLogin.Login(context.TODO(), "")
|
||||
err = azureLogin.Login(ctx, "")
|
||||
assert.NilError(t, err)
|
||||
|
||||
loginToken, err := azureLogin.tokenStore.readToken()
|
||||
|
@ -398,8 +404,8 @@ func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token az
|
|||
return args.Get(0).(azureToken), args.Error(1)
|
||||
}
|
||||
|
||||
func (s *MockAzureHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
|
||||
args := s.Called(authorizationURL, authorizationHeader)
|
||||
func (s *MockAzureHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) {
|
||||
args := s.Called(ctx, authorizationURL, authorizationHeader)
|
||||
return args.Get(0).([]byte), args.Int(1), args.Error(2)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue