mirror of https://github.com/docker/compose.git
Merge pull request #575 from docker/aci_login_fallback
Azure fallback to device code flow if we can’t open a browser
This commit is contained in:
commit
9b0dd5d8cd
|
@ -27,6 +27,9 @@ import (
|
|||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/Azure/go-autorest/autorest/adal"
|
||||
"github.com/Azure/go-autorest/autorest/azure/auth"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
@ -37,18 +40,29 @@ var (
|
|||
type apiHelper interface {
|
||||
queryToken(data url.Values, tenantID string) (azureToken, error)
|
||||
openAzureLoginPage(redirectURL string) error
|
||||
queryAuthorizationAPI(authorizationURL string, authorizationHeader string) ([]byte, int, error)
|
||||
queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error)
|
||||
getDeviceCodeFlowToken() (adal.Token, error)
|
||||
}
|
||||
|
||||
type azureAPIHelper struct{}
|
||||
|
||||
func (helper azureAPIHelper) getDeviceCodeFlowToken() (adal.Token, error) {
|
||||
deviceconfig := auth.NewDeviceFlowConfig(clientID, "common")
|
||||
deviceconfig.Resource = "https://management.core.windows.net/"
|
||||
spToken, err := deviceconfig.ServicePrincipalToken()
|
||||
if err != nil {
|
||||
return adal.Token{}, err
|
||||
}
|
||||
return spToken.Token(), err
|
||||
}
|
||||
|
||||
func (helper azureAPIHelper) openAzureLoginPage(redirectURL string) error {
|
||||
state := randomString("", 10)
|
||||
authURL := fmt.Sprintf(authorizeFormat, clientID, redirectURL, state, scopes)
|
||||
return openbrowser(authURL)
|
||||
}
|
||||
|
||||
func (helper azureAPIHelper) queryAuthorizationAPI(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
|
||||
func (helper azureAPIHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, authorizationURL, nil)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
|
@ -88,13 +102,13 @@ func openbrowser(address string) error {
|
|||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
if isWsl() {
|
||||
return exec.Command("wslview", address).Start()
|
||||
return exec.Command("wslview", address).Run()
|
||||
}
|
||||
return exec.Command("xdg-open", address).Start()
|
||||
return exec.Command("xdg-open", address).Run()
|
||||
case "windows":
|
||||
return exec.Command("rundll32", "url.dll,FileProtocolHandler", address).Start()
|
||||
return exec.Command("rundll32", "url.dll,FileProtocolHandler", address).Run()
|
||||
case "darwin":
|
||||
return exec.Command("open", address).Start()
|
||||
return exec.Command("open", address).Run()
|
||||
default:
|
||||
return fmt.Errorf("unsupported platform")
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ import (
|
|||
|
||||
"github.com/Azure/go-autorest/autorest"
|
||||
"github.com/Azure/go-autorest/autorest/adal"
|
||||
auth2 "github.com/Azure/go-autorest/autorest/azure/auth"
|
||||
"github.com/Azure/go-autorest/autorest/azure/auth"
|
||||
"github.com/Azure/go-autorest/autorest/date"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/oauth2"
|
||||
|
@ -38,9 +38,9 @@ import (
|
|||
|
||||
//go login process, derived from code sample provided by MS at https://github.com/devigned/go-az-cli-stuff
|
||||
const (
|
||||
authorizeFormat = "https://login.microsoftonline.com/organizations/oauth2/v2.0/authorize?response_type=code&client_id=%s&redirect_uri=%s&state=%s&prompt=select_account&response_mode=query&scope=%s"
|
||||
tokenEndpoint = "https://login.microsoftonline.com/%s/oauth2/v2.0/token"
|
||||
authorizationURL = "https://management.azure.com/tenants?api-version=2019-11-01"
|
||||
authorizeFormat = "https://login.microsoftonline.com/organizations/oauth2/v2.0/authorize?response_type=code&client_id=%s&redirect_uri=%s&state=%s&prompt=select_account&response_mode=query&scope=%s"
|
||||
tokenEndpoint = "https://login.microsoftonline.com/%s/oauth2/v2.0/token"
|
||||
getTenantURL = "https://management.azure.com/tenants?api-version=2019-11-01"
|
||||
// scopes for a multi-tenant app works for openid, email, other common scopes, but fails when trying to add a token
|
||||
// v1 scope like "https://management.azure.com/.default" for ARM access
|
||||
scopes = "offline_access https://management.azure.com/.default"
|
||||
|
@ -101,7 +101,7 @@ func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper) (*Azu
|
|||
// The resulting token does not include a refresh token
|
||||
func (login *AzureLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error {
|
||||
// Tried with auth2.NewUsernamePasswordConfig() but could not make this work with username / password, setting this for CI with clientID / clientSecret
|
||||
creds := auth2.NewClientCredentialsConfig(clientID, clientSecret, tenantID)
|
||||
creds := auth.NewClientCredentialsConfig(clientID, clientSecret, tenantID)
|
||||
|
||||
spToken, err := creds.ServicePrincipalToken()
|
||||
if err != nil {
|
||||
|
@ -132,6 +132,35 @@ func (login *AzureLoginService) Logout(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (login *AzureLoginService) getTenantAndValidateLogin(accessToken string, refreshToken string, requestedTenantID string) error {
|
||||
bits, statusCode, err := login.apiHelper.queryAPIWithHeader(getTenantURL, fmt.Sprintf("Bearer %s", accessToken))
|
||||
if err != nil {
|
||||
return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err)
|
||||
}
|
||||
|
||||
if statusCode != http.StatusOK {
|
||||
return errors.Wrapf(errdefs.ErrLoginFailed, "unable to login status code %d: %s", statusCode, bits)
|
||||
}
|
||||
var t tenantResult
|
||||
if err := json.Unmarshal(bits, &t); err != nil {
|
||||
return errors.Wrapf(errdefs.ErrLoginFailed, "unable to unmarshal tenant: %s", err)
|
||||
}
|
||||
tenantID, err := getTenantID(t.Value, requestedTenantID)
|
||||
if err != nil {
|
||||
return errors.Wrap(errdefs.ErrLoginFailed, err.Error())
|
||||
}
|
||||
tToken, err := login.refreshToken(refreshToken, tenantID)
|
||||
if err != nil {
|
||||
return errors.Wrapf(errdefs.ErrLoginFailed, "unable to refresh token: %s", err)
|
||||
}
|
||||
loginInfo := TokenInfo{TenantID: tenantID, Token: tToken}
|
||||
|
||||
if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil {
|
||||
return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Login performs an Azure login through a web browser
|
||||
func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID string) error {
|
||||
queryCh := make(chan localResponse, 1)
|
||||
|
@ -148,7 +177,12 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
|
|||
}
|
||||
|
||||
if err = login.apiHelper.openAzureLoginPage(redirectURL); err != nil {
|
||||
return err
|
||||
fmt.Println("Could not automatically open a browser, falling back to Azure device code flow authentication")
|
||||
token, err := login.apiHelper.getDeviceCodeFlowToken()
|
||||
if err != nil {
|
||||
return errors.Wrapf(errdefs.ErrLoginFailed, "could not get token using device code flow: %s", err)
|
||||
}
|
||||
return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID)
|
||||
}
|
||||
|
||||
select {
|
||||
|
@ -173,36 +207,8 @@ func (login *AzureLoginService) Login(ctx context.Context, requestedTenantID str
|
|||
if err != nil {
|
||||
return errors.Wrapf(errdefs.ErrLoginFailed, "access token request failed: %s", err)
|
||||
}
|
||||
|
||||
bits, statusCode, err := login.apiHelper.queryAuthorizationAPI(authorizationURL, fmt.Sprintf("Bearer %s", token.AccessToken))
|
||||
if err != nil {
|
||||
return errors.Wrapf(errdefs.ErrLoginFailed, "check auth failed: %s", err)
|
||||
}
|
||||
|
||||
switch statusCode {
|
||||
case http.StatusOK:
|
||||
var t tenantResult
|
||||
if err := json.Unmarshal(bits, &t); err != nil {
|
||||
return errors.Wrapf(errdefs.ErrLoginFailed, "unable to unmarshal tenant: %s", err)
|
||||
}
|
||||
tenantID, err := getTenantID(t.Value, requestedTenantID)
|
||||
if err != nil {
|
||||
return errors.Wrap(errdefs.ErrLoginFailed, err.Error())
|
||||
}
|
||||
tToken, err := login.refreshToken(token.RefreshToken, tenantID)
|
||||
if err != nil {
|
||||
return errors.Wrapf(errdefs.ErrLoginFailed, "unable to refresh token: %s", err)
|
||||
}
|
||||
loginInfo := TokenInfo{TenantID: tenantID, Token: tToken}
|
||||
|
||||
if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil {
|
||||
return errors.Wrapf(errdefs.ErrLoginFailed, "could not store login info: %s", err)
|
||||
}
|
||||
default:
|
||||
return errors.Wrapf(errdefs.ErrLoginFailed, "unable to login status code %d: %s", statusCode, bits)
|
||||
}
|
||||
return login.getTenantAndValidateLogin(token.AccessToken, token.RefreshToken, requestedTenantID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getTenantID(tenantValues []tenantValue, requestedTenantID string) (string, error) {
|
||||
|
|
|
@ -18,6 +18,7 @@ package login
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -27,6 +28,8 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/go-autorest/autorest/adal"
|
||||
|
||||
"github.com/stretchr/testify/mock"
|
||||
"gotest.tools/v3/assert"
|
||||
|
||||
|
@ -113,7 +116,7 @@ func TestInvalidLogin(t *testing.T) {
|
|||
redirectURL := args.Get(0).(string)
|
||||
err := queryKeyValue(redirectURL, "error", "access denied: login failed")
|
||||
assert.NilError(t, err)
|
||||
})
|
||||
}).Return(nil)
|
||||
|
||||
azureLogin, err := testLoginService(t, m)
|
||||
assert.NilError(t, err)
|
||||
|
@ -129,7 +132,7 @@ func TestValidLogin(t *testing.T) {
|
|||
redirectURL = args.Get(0).(string)
|
||||
err := queryKeyValue(redirectURL, "code", "123456879")
|
||||
assert.NilError(t, err)
|
||||
})
|
||||
}).Return(nil)
|
||||
|
||||
m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
|
||||
//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
|
||||
|
@ -149,7 +152,7 @@ func TestValidLogin(t *testing.T) {
|
|||
|
||||
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
|
||||
|
||||
m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
data := refreshTokenData("firstRefreshToken")
|
||||
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
|
||||
RefreshToken: "newRefreshToken",
|
||||
|
@ -179,7 +182,7 @@ func TestValidLoginRequestedTenant(t *testing.T) {
|
|||
redirectURL = args.Get(0).(string)
|
||||
err := queryKeyValue(redirectURL, "code", "123456879")
|
||||
assert.NilError(t, err)
|
||||
})
|
||||
}).Return(nil)
|
||||
|
||||
m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
|
||||
//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
|
||||
|
@ -200,7 +203,7 @@ func TestValidLoginRequestedTenant(t *testing.T) {
|
|||
authBody := `{"value":[{"id":"/tenants/00000000-c56d-43e8-9549-dd230ce8a038","tenantId":"00000000-c56d-43e8-9549-dd230ce8a038"},
|
||||
{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
|
||||
|
||||
m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
data := refreshTokenData("firstRefreshToken")
|
||||
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
|
||||
RefreshToken: "newRefreshToken",
|
||||
|
@ -230,7 +233,7 @@ func TestLoginNoTenant(t *testing.T) {
|
|||
redirectURL = args.Get(0).(string)
|
||||
err := queryKeyValue(redirectURL, "code", "123456879")
|
||||
assert.NilError(t, err)
|
||||
})
|
||||
}).Return(nil)
|
||||
|
||||
m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
|
||||
//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
|
||||
|
@ -249,7 +252,7 @@ func TestLoginNoTenant(t *testing.T) {
|
|||
}, nil)
|
||||
|
||||
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
|
||||
m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
|
||||
azureLogin, err := testLoginService(t, m)
|
||||
assert.NilError(t, err)
|
||||
|
@ -265,7 +268,7 @@ func TestLoginRequestedTenantNotFound(t *testing.T) {
|
|||
redirectURL = args.Get(0).(string)
|
||||
err := queryKeyValue(redirectURL, "code", "123456879")
|
||||
assert.NilError(t, err)
|
||||
})
|
||||
}).Return(nil)
|
||||
|
||||
m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
|
||||
//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
|
||||
|
@ -284,7 +287,7 @@ func TestLoginRequestedTenantNotFound(t *testing.T) {
|
|||
}, nil)
|
||||
|
||||
authBody := `{"value":[]}`
|
||||
m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
|
||||
azureLogin, err := testLoginService(t, m)
|
||||
assert.NilError(t, err)
|
||||
|
@ -300,7 +303,7 @@ func TestLoginAuthorizationFailed(t *testing.T) {
|
|||
redirectURL = args.Get(0).(string)
|
||||
err := queryKeyValue(redirectURL, "code", "123456879")
|
||||
assert.NilError(t, err)
|
||||
})
|
||||
}).Return(nil)
|
||||
|
||||
m.On("queryToken", mock.MatchedBy(func(data url.Values) bool {
|
||||
//Need a matcher here because the value of redirectUrl is not known until executing openAzureLoginPage
|
||||
|
@ -320,7 +323,7 @@ func TestLoginAuthorizationFailed(t *testing.T) {
|
|||
|
||||
authBody := `[access denied]`
|
||||
|
||||
m.On("queryAuthorizationAPI", authorizationURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
|
||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 400, nil)
|
||||
|
||||
azureLogin, err := testLoginService(t, m)
|
||||
assert.NilError(t, err)
|
||||
|
@ -329,6 +332,36 @@ func TestLoginAuthorizationFailed(t *testing.T) {
|
|||
assert.Error(t, err, "unable to login status code 400: [access denied]: login failed")
|
||||
}
|
||||
|
||||
func TestValidThroughDeviceCodeFlow(t *testing.T) {
|
||||
m := &MockAzureHelper{}
|
||||
m.On("openAzureLoginPage", mock.AnythingOfType("string")).Return(errors.New("Could not open browser"))
|
||||
m.On("getDeviceCodeFlowToken").Return(adal.Token{AccessToken: "firstAccessToken", RefreshToken: "firstRefreshToken"}, nil)
|
||||
|
||||
authBody := `{"value":[{"id":"/tenants/12345a7c-c56d-43e8-9549-dd230ce8a038","tenantId":"12345a7c-c56d-43e8-9549-dd230ce8a038"}]}`
|
||||
|
||||
m.On("queryAPIWithHeader", getTenantURL, "Bearer firstAccessToken").Return([]byte(authBody), 200, nil)
|
||||
data := refreshTokenData("firstRefreshToken")
|
||||
m.On("queryToken", data, "12345a7c-c56d-43e8-9549-dd230ce8a038").Return(azureToken{
|
||||
RefreshToken: "newRefreshToken",
|
||||
AccessToken: "newAccessToken",
|
||||
ExpiresIn: 3600,
|
||||
Foci: "1",
|
||||
}, nil)
|
||||
azureLogin, err := testLoginService(t, m)
|
||||
assert.NilError(t, err)
|
||||
|
||||
err = azureLogin.Login(context.TODO(), "")
|
||||
assert.NilError(t, err)
|
||||
|
||||
loginToken, err := azureLogin.tokenStore.readToken()
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, loginToken.Token.AccessToken, "newAccessToken")
|
||||
assert.Equal(t, loginToken.Token.RefreshToken, "newRefreshToken")
|
||||
assert.Assert(t, time.Now().Add(3500*time.Second).Before(loginToken.Token.Expiry))
|
||||
assert.Equal(t, loginToken.TenantID, "12345a7c-c56d-43e8-9549-dd230ce8a038")
|
||||
assert.Equal(t, loginToken.Token.Type(), "Bearer")
|
||||
}
|
||||
|
||||
func refreshTokenData(refreshToken string) url.Values {
|
||||
return url.Values{
|
||||
"grant_type": []string{"refresh_token"},
|
||||
|
@ -355,17 +388,22 @@ type MockAzureHelper struct {
|
|||
mock.Mock
|
||||
}
|
||||
|
||||
func (s *MockAzureHelper) getDeviceCodeFlowToken() (adal.Token, error) {
|
||||
args := s.Called()
|
||||
return args.Get(0).(adal.Token), args.Error(1)
|
||||
}
|
||||
|
||||
func (s *MockAzureHelper) queryToken(data url.Values, tenantID string) (token azureToken, err error) {
|
||||
args := s.Called(data, tenantID)
|
||||
return args.Get(0).(azureToken), args.Error(1)
|
||||
}
|
||||
|
||||
func (s *MockAzureHelper) queryAuthorizationAPI(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
|
||||
func (s *MockAzureHelper) queryAPIWithHeader(authorizationURL string, authorizationHeader string) ([]byte, int, error) {
|
||||
args := s.Called(authorizationURL, authorizationHeader)
|
||||
return args.Get(0).([]byte), args.Int(1), args.Error(2)
|
||||
}
|
||||
|
||||
func (s *MockAzureHelper) openAzureLoginPage(redirectURL string) error {
|
||||
s.Called(redirectURL)
|
||||
return nil
|
||||
args := s.Called(redirectURL)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue