Azure: Refactor creation of login server

Signed-off-by: Christopher Crone <christopher.crone@docker.com>
This commit is contained in:
Christopher Crone 2020-05-22 13:38:30 +02:00
parent dae54a3c1f
commit 35789ace12
7 changed files with 177 additions and 131 deletions

View File

@ -280,5 +280,8 @@ type aciCloudService struct {
} }
func (cs *aciCloudService) Login(ctx context.Context, params map[string]string) error { func (cs *aciCloudService) Login(ctx context.Context, params map[string]string) error {
return cs.loginService.Login(ctx) if err := cs.loginService.Login(ctx); err != nil {
return errors.Wrap(errdefs.ErrLoginFailed, err.Error())
}
return nil
} }

119
azure/login/local_server.go Normal file
View File

@ -0,0 +1,119 @@
package login
import (
"fmt"
"net"
"net/http"
"net/url"
"github.com/pkg/errors"
)
const failHTML = `
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>Login failed</title>
</head>
<body>
<h4>Some failures occurred during the authentication</h4>
<p>You can log an issue at <a href="https://github.com/azure/azure-cli/issues">Azure CLI GitHub Repository</a> and we will assist you in resolving it.</p>
</body>
</html>
`
const successHTML = `
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<meta http-equiv="refresh" content="10;url=https://docs.microsoft.com/cli/azure/">
<title>Login successfully</title>
</head>
<body>
<h4>You have logged into Microsoft Azure!</h4>
<p>You can close this window, or we will redirect you to the <a href="https://docs.microsoft.com/cli/azure/">Azure CLI documents</a> in 10 seconds.</p>
</body>
</html>
`
const (
// redirectHost is where the user's browser is redirected on authentication
// completion. Note that only registered hosts can be used. i.e.:
// "localhost" works but "127.0.0.1" does not.
redirectHost = "localhost"
)
type localResponse struct {
values url.Values
err error
}
// LocalServer is an Azure login server
type LocalServer struct {
httpServer *http.Server
listener net.Listener
queryCh chan localResponse
}
// NewLocalServer creates an Azure login server
func NewLocalServer(queryCh chan localResponse) (*LocalServer, error) {
s := &LocalServer{queryCh: queryCh}
mux := http.NewServeMux()
mux.HandleFunc("/", s.handler())
s.httpServer = &http.Server{Handler: mux}
l, err := net.Listen("tcp", redirectHost+":0")
if err != nil {
return nil, err
}
s.listener = l
p := l.Addr().(*net.TCPAddr).Port
if p == 0 {
return nil, errors.New("unable to allocate login server port")
}
return s, nil
}
// Serve starts the local Azure login server
func (s *LocalServer) Serve() {
go func() {
if err := s.httpServer.Serve(s.listener); err != nil {
s.queryCh <- localResponse{
err: errors.Wrap(err, "unable to start login server"),
}
}
}()
}
// Close stops the local Azure login server
func (s *LocalServer) Close() {
_ = s.httpServer.Close()
}
// Addr returns the address that the local Azure server is service to
func (s *LocalServer) Addr() string {
return fmt.Sprintf("http://%s:%d", redirectHost, s.listener.Addr().(*net.TCPAddr).Port)
}
func (s *LocalServer) handler() func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if _, hasCode := r.URL.Query()["code"]; hasCode {
if _, err := w.Write([]byte(successHTML)); err != nil {
s.queryCh <- localResponse{
err: errors.Wrap(err, "unable to write success HTML"),
}
} else {
s.queryCh <- localResponse{values: r.URL.Query()}
}
} else {
if _, err := w.Write([]byte(failHTML)); err != nil {
s.queryCh <- localResponse{
err: errors.Wrap(err, "unable to write fail HTML"),
}
} else {
s.queryCh <- localResponse{values: r.URL.Query()}
}
}
}
}

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log"
"net/http"
"net/url" "net/url"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
@ -12,8 +13,6 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/docker/api/errdefs"
"github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal" "github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure/cli" "github.com/Azure/go-autorest/autorest/azure/cli"
@ -77,28 +76,32 @@ func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper) (Azur
}, nil }, nil
} }
//Login perform azure login through browser // Login performs an Azure login through a web browser
func (login AzureLoginService) Login(ctx context.Context) error { func (login AzureLoginService) Login(ctx context.Context) error {
queryCh := make(chan url.Values, 1) queryCh := make(chan localResponse, 1)
serverPort, err := startLoginServer(queryCh) s, err := NewLocalServer(queryCh)
if err != nil { if err != nil {
return err return err
} }
s.Serve()
defer s.Close()
redirectURL := "http://localhost:" + strconv.Itoa(serverPort) redirectURL := s.Addr()
if redirectURL == "" {
return errors.New("empty redirect URL")
}
login.apiHelper.openAzureLoginPage(redirectURL) login.apiHelper.openAzureLoginPage(redirectURL)
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
case qsValues := <-queryCh: case q := <-queryCh:
errorMsg, hasError := qsValues["error"] if q.err != nil {
if hasError { return errors.Wrap(err, "unhandled local login server error")
return fmt.Errorf("login failed : %s", errorMsg)
} }
code, hasCode := qsValues["code"] code, hasCode := q.values["code"]
if !hasCode { if !hasCode {
return errdefs.ErrLoginFailed return errors.New("no login code")
} }
data := url.Values{ data := url.Values{
"grant_type": []string{"authorization_code"}, "grant_type": []string{"authorization_code"},
@ -109,38 +112,35 @@ func (login AzureLoginService) Login(ctx context.Context) error {
} }
token, err := login.apiHelper.queryToken(data, "organizations") token, err := login.apiHelper.queryToken(data, "organizations")
if err != nil { if err != nil {
return errors.Wrap(err, "Access token request failed") return errors.Wrap(err, "access token request failed")
} }
bits, statusCode, err := login.apiHelper.queryAuthorizationAPI(authorizationURL, fmt.Sprintf("Bearer %s", token.AccessToken)) bits, statusCode, err := login.apiHelper.queryAuthorizationAPI(authorizationURL, fmt.Sprintf("Bearer %s", token.AccessToken))
if err != nil { if err != nil {
return errors.Wrap(err, "login failed") return errors.Wrap(err, "check auth failed")
} }
if statusCode == 200 { switch statusCode {
var tenantResult tenantResult case http.StatusOK:
if err := json.Unmarshal(bits, &tenantResult); err != nil { var t tenantResult
return errors.Wrap(err, "login failed") if err := json.Unmarshal(bits, &t); err != nil {
return errors.Wrap(err, "unable to unmarshal tenant")
} }
tenantID := tenantResult.Value[0].TenantID tID := t.Value[0].TenantID
tenantToken, err := login.refreshToken(token.RefreshToken, tenantID) tToken, err := login.refreshToken(token.RefreshToken, tID)
if err != nil { if err != nil {
return errors.Wrap(err, "login failed") return errors.Wrap(err, "unable to refresh token")
} }
loginInfo := TokenInfo{TenantID: tenantID, Token: tenantToken} loginInfo := TokenInfo{TenantID: tID, Token: tToken}
err = login.tokenStore.writeLoginInfo(loginInfo) if err := login.tokenStore.writeLoginInfo(loginInfo); err != nil {
return errors.Wrap(err, "could not store login info")
if err != nil {
return errors.Wrap(err, "login failed")
} }
fmt.Println("Login Succeeded") default:
return fmt.Errorf("unable to login status code %d: %s", statusCode, bits)
return nil
} }
return fmt.Errorf("login failed : " + string(bits))
} }
return nil
} }
func getTokenStorePath() string { func getTokenStorePath() string {

View File

@ -2,7 +2,6 @@ package login
import ( import (
"context" "context"
"errors"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
@ -20,14 +19,14 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
type LoginSuiteTest struct { type LoginSuite struct {
suite.Suite suite.Suite
dir string dir string
mockHelper *MockAzureHelper mockHelper *MockAzureHelper
azureLogin AzureLoginService azureLogin AzureLoginService
} }
func (suite *LoginSuiteTest) BeforeTest(suiteName, testName string) { func (suite *LoginSuite) BeforeTest(suiteName, testName string) {
dir, err := ioutil.TempDir("", "test_store") dir, err := ioutil.TempDir("", "test_store")
Expect(err).To(BeNil()) Expect(err).To(BeNil())
@ -37,12 +36,12 @@ func (suite *LoginSuiteTest) BeforeTest(suiteName, testName string) {
Expect(err).To(BeNil()) Expect(err).To(BeNil())
} }
func (suite *LoginSuiteTest) AfterTest(suiteName, testName string) { func (suite *LoginSuite) AfterTest(suiteName, testName string) {
err := os.RemoveAll(suite.dir) err := os.RemoveAll(suite.dir)
Expect(err).To(BeNil()) Expect(err).To(BeNil())
} }
func (suite *LoginSuiteTest) TestRefreshInValidToken() { func (suite *LoginSuite) TestRefreshInValidToken() {
data := refreshTokenData("refreshToken") data := refreshTokenData("refreshToken")
suite.mockHelper.On("queryToken", data, "123456").Return(azureToken{ suite.mockHelper.On("queryToken", data, "123456").Return(azureToken{
RefreshToken: "newRefreshToken", RefreshToken: "newRefreshToken",
@ -77,7 +76,7 @@ func (suite *LoginSuiteTest) TestRefreshInValidToken() {
Expect(storedToken.Token.Expiry).To(BeTemporally(">", time.Now().Add(3500*time.Second))) Expect(storedToken.Token.Expiry).To(BeTemporally(">", time.Now().Add(3500*time.Second)))
} }
func (suite *LoginSuiteTest) TestDoesNotRefreshValidToken() { func (suite *LoginSuite) TestDoesNotRefreshValidToken() {
expiryDate := time.Now().Add(1 * time.Hour) expiryDate := time.Now().Add(1 * time.Hour)
err := suite.azureLogin.tokenStore.writeLoginInfo(TokenInfo{ err := suite.azureLogin.tokenStore.writeLoginInfo(TokenInfo{
TenantID: "123456", TenantID: "123456",
@ -96,7 +95,7 @@ func (suite *LoginSuiteTest) TestDoesNotRefreshValidToken() {
Expect(token.AccessToken).To(Equal("accessToken")) Expect(token.AccessToken).To(Equal("accessToken"))
} }
func (suite *LoginSuiteTest) TestInvalidLogin() { func (suite *LoginSuite) TestInvalidLogin() {
suite.mockHelper.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { suite.mockHelper.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
redirectURL := args.Get(0).(string) redirectURL := args.Get(0).(string)
err := queryKeyValue(redirectURL, "error", "access denied") err := queryKeyValue(redirectURL, "error", "access denied")
@ -108,10 +107,10 @@ func (suite *LoginSuiteTest) TestInvalidLogin() {
Expect(err).To(BeNil()) Expect(err).To(BeNil())
err = azureLogin.Login(context.TODO()) err = azureLogin.Login(context.TODO())
Expect(err).To(MatchError(errors.New("login failed : [access denied]"))) Expect(err.Error()).To(BeEquivalentTo("no login code"))
} }
func (suite *LoginSuiteTest) TestValidLogin() { func (suite *LoginSuite) TestValidLogin() {
var redirectURL string var redirectURL string
suite.mockHelper.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { suite.mockHelper.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
redirectURL = args.Get(0).(string) redirectURL = args.Get(0).(string)
@ -161,7 +160,7 @@ func (suite *LoginSuiteTest) TestValidLogin() {
Expect(loginToken.Token.Type()).To(Equal("Bearer")) Expect(loginToken.Token.Type()).To(Equal("Bearer"))
} }
func (suite *LoginSuiteTest) TestLoginAuthorizationFailed() { func (suite *LoginSuite) TestLoginAuthorizationFailed() {
var redirectURL string var redirectURL string
suite.mockHelper.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) { suite.mockHelper.On("openAzureLoginPage", mock.AnythingOfType("string")).Run(func(args mock.Arguments) {
redirectURL = args.Get(0).(string) redirectURL = args.Get(0).(string)
@ -193,7 +192,7 @@ func (suite *LoginSuiteTest) TestLoginAuthorizationFailed() {
Expect(err).To(BeNil()) Expect(err).To(BeNil())
err = azureLogin.Login(context.TODO()) err = azureLogin.Login(context.TODO())
Expect(err).To(MatchError(errors.New("login failed : [access denied]"))) Expect(err.Error()).To(BeEquivalentTo("unable to login status code 400: [access denied]"))
} }
func refreshTokenData(refreshToken string) url.Values { func refreshTokenData(refreshToken string) url.Values {
@ -218,7 +217,7 @@ func queryKeyValue(redirectURL string, key string, value string) error {
func TestLoginSuite(t *testing.T) { func TestLoginSuite(t *testing.T) {
RegisterTestingT(t) RegisterTestingT(t)
suite.Run(t, new(LoginSuiteTest)) suite.Run(t, new(LoginSuite))
} }
type MockAzureHelper struct { type MockAzureHelper struct {

View File

@ -1,83 +0,0 @@
package login
import (
"fmt"
"net"
"net/http"
"net/url"
)
const loginFailedHTML = `
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>Login failed</title>
</head>
<body>
<h4>Some failures occurred during the authentication</h4>
<p>You can log an issue at <a href="https://github.com/azure/azure-cli/issues">Azure CLI GitHub Repository</a> and we will assist you in resolving it.</p>
</body>
</html>
`
const successfullLoginHTML = `
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<meta http-equiv="refresh" content="10;url=https://docs.microsoft.com/cli/azure/">
<title>Login successfully</title>
</head>
<body>
<h4>You have logged into Microsoft Azure!</h4>
<p>You can close this window, or we will redirect you to the <a href="https://docs.microsoft.com/cli/azure/">Azure CLI documents</a> in 10 seconds.</p>
</body>
</html>
`
func startLoginServer(queryCh chan url.Values) (int, error) {
mux := http.NewServeMux()
mux.HandleFunc("/", queryHandler(queryCh))
listener, err := net.Listen("tcp", ":0")
if err != nil {
return 0, err
}
availablePort := listener.Addr().(*net.TCPAddr).Port
server := &http.Server{Handler: mux}
go func() {
if err := server.Serve(listener); err != nil {
queryCh <- url.Values{
"error": []string{fmt.Sprintf("error starting http server with: %v", err)},
}
}
}()
return availablePort, nil
}
func queryHandler(queryCh chan url.Values) func(w http.ResponseWriter, r *http.Request) {
queryHandler := func(w http.ResponseWriter, r *http.Request) {
_, hasCode := r.URL.Query()["code"]
if hasCode {
_, err := w.Write([]byte(successfullLoginHTML))
if err != nil {
queryCh <- url.Values{
"error": []string{err.Error()},
}
} else {
queryCh <- r.URL.Query()
}
} else {
_, err := w.Write([]byte(loginFailedHTML))
if err != nil {
queryCh <- url.Values{
"error": []string{err.Error()},
}
} else {
queryCh <- r.URL.Query()
}
}
}
return queryHandler
}

View File

@ -1,6 +1,7 @@
package login package login
import ( import (
"fmt"
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -35,7 +36,7 @@ func runLogin(cmd *cobra.Command, args []string) error {
case "azure": case "azure":
return cloudLogin(cmd, "aci") return cloudLogin(cmd, "aci")
default: default:
return errors.New("Unknown backend type for cloud login : " + backend) return errors.New("unknown backend type for cloud login: " + backend)
} }
} }
return dockerclassic.ExecCmd(cmd) return dockerclassic.ExecCmd(cmd)
@ -47,6 +48,13 @@ func cloudLogin(cmd *cobra.Command, backendType string) error {
if err != nil { if err != nil {
return errors.Wrap(err, "cannot connect to backend") return errors.Wrap(err, "cannot connect to backend")
} }
return cs.Login(ctx, nil) err = cs.Login(ctx, nil)
if err != nil {
return err
}
if cmd.Context().Err() != nil {
return errors.New("login canceled")
}
fmt.Println("login successful")
return nil
} }

View File

@ -90,7 +90,7 @@ func (s *E2eSuite) TestClassicLogin() {
func (s *E2eSuite) TestCloudLogin() { func (s *E2eSuite) TestCloudLogin() {
output, err := s.NewDockerCommand("login", "mycloudbackend").Exec() output, err := s.NewDockerCommand("login", "mycloudbackend").Exec()
Expect(output).To(ContainSubstring("Unknown backend type for cloud login : mycloudbackend")) Expect(output).To(ContainSubstring("unknown backend type for cloud login: mycloudbackend"))
Expect(err).NotTo(BeNil()) Expect(err).NotTo(BeNil())
} }