Merge pull request #577 from docker/aci_device_login_ctrlc

ACI: Allow Ctrl+C to cancel CLI when using Azure Device Code Flow login
This commit is contained in:
Guillaume Tardif 2020-09-04 13:07:54 +02:00 committed by GitHub
commit cbb416976a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 51 additions and 25 deletions

View File

@ -17,6 +17,7 @@
package login
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
@ -40,7 +41,7 @@ var (
type apiHelper interface {
queryToken(data url.Values, tenantID string) (azureToken, error)
openAzureLoginPage(redirectURL string) error
queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error)
queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error)
getDeviceCodeFlowToken() (adal.Token, error)
}
@ -62,11 +63,12 @@ func (helper azureAPIHelper) openAzureLoginPage(redirectURL string) error {
return openbrowser(authURL)
}
func (helper azureAPIHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
func (helper azureAPIHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) {
req, err := http.NewRequest(http.MethodGet, authorizationURL, nil)
if err != nil {
return nil, 0, err
}
req = req.WithContext(ctx)
req.Header.Add("Authorization", authorizationHeader)
res, err := http.DefaultClient.Do(req)
if err != nil {

View File

@ -132,8 +132,8 @@ func (login *AzureLoginService) Logout(ctx context.Context) error {
return err
}
func (login *AzureLoginService) getTenantAndValidateLogin(accessToken string, refreshToken string, requestedTenantID string) error {
bits, statusCode, err := login.apiHelper.queryAPIWithHeader(getTenantURL, fmt.Sprintf("Bearer %s", accessToken))
func (login *AzureLoginService) getTenantAndValidateLogin(ctx context.Context, accessToken string, refreshToken string, requestedTenantID string) error {
bits, statusCode, err := login.apiHelper.queryAPIWithHeader(ctx, getTenantURL, fmt.Sprintf("Bearer %s", accessToken))
if err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err)
}
@ -176,18 +176,20 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
return errors.Wrap(errdefs.ErrLoginFailed, "empty redirect URL")
}
deviceCodeFlowCh := make(chan deviceCodeFlowResponse, 1)
if err = login.apiHelper.openAzureLoginPage(redirectURL); err != nil {
fmt.Println("Could not automatically open a browser, falling back to Azure device code flow authentication")
token, err := login.apiHelper.getDeviceCodeFlowToken()
if err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err)
}
return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID)
login.startDeviceCodeFlow(deviceCodeFlowCh)
}
select {
case <-ctx.Done():
return ctx.Err()
case dcft := <-deviceCodeFlowCh:
if dcft.err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err)
}
token := dcft.token
return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
case q := <-queryCh:
if q.err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "unhandled local login server error: %s", err)
@ -207,10 +209,26 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
if err != nil {
return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err)
}
return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID)
return login.getTenantAndValidateLogin(ctx, token.AccessToken, token.RefreshToken, requestedTenantID)
}
}
type deviceCodeFlowResponse struct {
token adal.Token
err error
}
func (login *AzureLoginService) startDeviceCodeFlow(deviceCodeFlowCh chan deviceCodeFlowResponse) {
fmt.Println("Could not automatically open a browser, falling back to Azure device code flow authentication")
go func() {
token, err := login.apiHelper.getDeviceCodeFlowToken()
if err != nil {
deviceCodeFlowCh <- deviceCodeFlowResponse{err: err}
}
deviceCodeFlowCh <- deviceCodeFlowResponse{token: token}
}()
}
func getTenantID(tenantValues []tenantValue, requestedTenantID string) (string, error) {
if requestedTenantID == "" {
if len(tenantValues) < 1 {

View File

@ -127,6 +127,7 @@ func TestInvalidLogin(t *testing.T) {
func TestValidLogin(t *testing.T) {
var redirectURL string
ctx := context.TODO()
m := &MockAzureHelper{}
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
redirectURL = args.Get(0).(string)
@ -152,7 +153,7 @@ func TestValidLogin(t *testing.T) {
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken")
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
RefreshToken: "newRefreshToken",
@ -163,7 +164,7 @@ func TestValidLogin(t *testing.T) {
azureLogin, err := testLoginService(t, m)
assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "")
err = azureLogin.Login(ctx, "")
assert.NilError(t, err)
loginToken, err := azureLogin.tokenStore.readToken()
@ -203,7 +204,8 @@ func TestValidLoginRequestedTenant(t *testing.T) {
authBody := `{"value":[{"id":"/tenants/00000000-c56d-43e8-9549-dd230ce8a038","tenantId":"00000000-c56d-43e8-9549-dd230ce8a038"},
{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
ctx := context.TODO()
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken")
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
RefreshToken: "newRefreshToken",
@ -214,7 +216,7 @@ func TestValidLoginRequestedTenant(t *testing.T) {
azureLogin, err := testLoginService(t, m)
assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "12345a7c-c56d-43e8-9549-dd230ce8a038")
err = azureLogin.Login(ctx, "12345a7c-c56d-43e8-9549-dd230ce8a038")
assert.NilError(t, err)
loginToken, err := azureLogin.tokenStore.readToken()
@ -251,13 +253,14 @@ func TestLoginNoTenant(t *testing.T) {
Foci: "1",
}, nil)
ctx := context.TODO()
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
azureLogin, err := testLoginService(t, m)
assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "00000000-c56d-43e8-9549-dd230ce8a038")
err = azureLogin.Login(ctx, "00000000-c56d-43e8-9549-dd230ce8a038")
assert.Error(t, err, "could not find requested azure tenant 00000000-c56d-43e8-9549-dd230ce8a038: login failed")
}
@ -286,13 +289,14 @@ func TestLoginRequestedTenantNotFound(t *testing.T) {
Foci: "1",
}, nil)
ctx := context.TODO()
authBody := `{"value":[]}`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
azureLogin, err := testLoginService(t, m)
assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "")
err = azureLogin.Login(ctx, "")
assert.Error(t, err, "could not find azure tenant: login failed")
}
@ -323,12 +327,13 @@ func TestLoginAuthorizationFailed(t *testing.T) {
authBody := `[access denied]`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
ctx := context.TODO()
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
azureLogin, err := testLoginService(t, m)
assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "")
err = azureLogin.Login(ctx, "")
assert.Error(t, err, "unable to login status code 400: [access denied]: login failed")
}
@ -339,7 +344,8 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) {
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
ctx := context.TODO()
m.On("queryAPIWithHeader", ctx, getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
data := refreshTokenData("firstRefreshToken")
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
RefreshToken: "newRefreshToken",
@ -350,7 +356,7 @@ func TestValidThroughDeviceCodeFlow(t *testing.T) {
azureLogin, err := testLoginService(t, m)
assert.NilError(t, err)
err = azureLogin.Login(context.TODO(), "")
err = azureLogin.Login(ctx, "")
assert.NilError(t, err)
loginToken, err := azureLogin.tokenStore.readToken()
@ -398,8 +404,8 @@ func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token az
return args.Get(0).(azureToken), args.Error(1)
}
func (s *MockAzureHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
args := s.Called(authorizationURL, authorizationHeader)
func (s *MockAzureHelper) queryAPIWithHeader(ctx context.Context, authorizationURL string, authorizationHeader string) ([]byte, int, error) {
args := s.Called(ctx, authorizationURL, authorizationHeader)
return args.Get(0).([]byte), args.Int(1), args.Error(2)
}