Fix tokenStore not creating ~/.azure folder if not exist

This commit is contained in:
Guillaume Tardif 2020-05-13 23:33:16 +02:00
parent 8b116b7c73
commit 146dd3e639
6 changed files with 132 additions and 41 deletions

View File

@ -235,7 +235,7 @@ func getACIContainerLogs(ctx context.Context, aciContext store.AciContext, conta
} }
func getContainerGroupsClient(subscriptionID string) (containerinstance.ContainerGroupsClient, error) { func getContainerGroupsClient(subscriptionID string) (containerinstance.ContainerGroupsClient, error) {
auth, err := login.NewAzureLoginService().NewAuthorizerFromLogin() auth, err := login.NewAuthorizerFromLogin()
if err != nil { if err != nil {
return containerinstance.ContainerGroupsClient{}, err return containerinstance.ContainerGroupsClient{}, err
} }
@ -248,7 +248,7 @@ func getContainerGroupsClient(subscriptionID string) (containerinstance.Containe
} }
func getContainerClient(subscriptionID string) (containerinstance.ContainerClient, error) { func getContainerClient(subscriptionID string) (containerinstance.ContainerClient, error) {
auth, err := login.NewAzureLoginService().NewAuthorizerFromLogin() auth, err := login.NewAuthorizerFromLogin()
if err != nil { if err != nil {
return containerinstance.ContainerClient{}, err return containerinstance.ContainerClient{}, err
} }
@ -259,7 +259,7 @@ func getContainerClient(subscriptionID string) (containerinstance.ContainerClien
func getSubscriptionsClient() subscription.SubscriptionsClient { func getSubscriptionsClient() subscription.SubscriptionsClient {
subc := subscription.NewSubscriptionsClient() subc := subscription.NewSubscriptionsClient()
authorizer, _ := login.NewAzureLoginService().NewAuthorizerFromLogin() authorizer, _ := login.NewAuthorizerFromLogin()
subc.Authorizer = authorizer subc.Authorizer = authorizer
return subc return subc
} }
@ -267,7 +267,7 @@ func getSubscriptionsClient() subscription.SubscriptionsClient {
// GetGroupsClient ... // GetGroupsClient ...
func GetGroupsClient(subscriptionID string) resources.GroupsClient { func GetGroupsClient(subscriptionID string) resources.GroupsClient {
groupsClient := resources.NewGroupsClient(subscriptionID) groupsClient := resources.NewGroupsClient(subscriptionID)
authorizer, _ := login.NewAzureLoginService().NewAuthorizerFromLogin() authorizer, _ := login.NewAuthorizerFromLogin()
groupsClient.Authorizer = authorizer groupsClient.Authorizer = authorizer
return groupsClient return groupsClient
} }

View File

@ -52,14 +52,18 @@ func New(ctx context.Context) (backend.Service, error) {
} }
aciContext, _ := metadata.Metadata.Data.(store.AciContext) aciContext, _ := metadata.Metadata.Data.(store.AciContext)
auth, _ := login.NewAzureLoginService().NewAuthorizerFromLogin() auth, _ := login.NewAuthorizerFromLogin()
containerGroupsClient := containerinstance.NewContainerGroupsClient(aciContext.SubscriptionID) containerGroupsClient := containerinstance.NewContainerGroupsClient(aciContext.SubscriptionID)
containerGroupsClient.Authorizer = auth containerGroupsClient.Authorizer = auth
return getAciAPIService(containerGroupsClient, aciContext), nil return getAciAPIService(containerGroupsClient, aciContext)
} }
func getAciAPIService(cgc containerinstance.ContainerGroupsClient, aciCtx store.AciContext) *aciAPIService { func getAciAPIService(cgc containerinstance.ContainerGroupsClient, aciCtx store.AciContext) (*aciAPIService, error) {
service, err := login.NewAzureLoginService()
if err != nil {
return nil, err
}
return &aciAPIService{ return &aciAPIService{
aciContainerService: aciContainerService{ aciContainerService: aciContainerService{
containerGroupsClient: cgc, containerGroupsClient: cgc,
@ -69,9 +73,9 @@ func getAciAPIService(cgc containerinstance.ContainerGroupsClient, aciCtx store.
ctx: aciCtx, ctx: aciCtx,
}, },
aciCloudService: aciCloudService{ aciCloudService: aciCloudService{
loginService: login.NewAzureLoginService(), loginService: service,
}, },
} }, nil
} }
type aciAPIService struct { type aciAPIService struct {

View File

@ -68,25 +68,27 @@ type AzureLoginService struct {
apiHelper apiHelper apiHelper apiHelper
} }
const tokenFilename = "dockerAccessToken.json" const tokenStoreFilename = "dockerAccessToken.json"
func getTokenStorePath() string { func getTokenStorePath() string {
cliPath, _ := cli.AccessTokensPath() cliPath, _ := cli.AccessTokensPath()
return filepath.Join(filepath.Dir(cliPath), tokenFilename) return filepath.Join(filepath.Dir(cliPath), tokenStoreFilename)
} }
// NewAzureLoginService creates a NewAzureLoginService // NewAzureLoginService creates a NewAzureLoginService
func NewAzureLoginService() AzureLoginService { func NewAzureLoginService() (AzureLoginService, error) {
return newAzureLoginServiceFromPath(getTokenStorePath(), azureAPIHelper{}) return newAzureLoginServiceFromPath(getTokenStorePath(), azureAPIHelper{})
} }
func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper) AzureLoginService { func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper) (AzureLoginService, error) {
return AzureLoginService{ store, err := newTokenStore(tokenStorePath)
tokenStore: tokenStore{ if err != nil {
filePath: tokenStorePath, return AzureLoginService{}, err
},
apiHelper: helper,
} }
return AzureLoginService{
tokenStore: store,
apiHelper: helper,
}, nil
} }
type apiHelper interface { type apiHelper interface {
@ -229,20 +231,21 @@ func queryHandler(queryCh chan url.Values) func(w http.ResponseWriter, r *http.R
return queryHandler return queryHandler
} }
func (helper azureAPIHelper) queryToken(data url.Values, tenantID string) (token azureToken, err error) { func (helper azureAPIHelper) queryToken(data url.Values, tenantID string) (azureToken, error) {
res, err := http.Post(fmt.Sprintf(tokenEndpoint, tenantID), "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) res, err := http.Post(fmt.Sprintf(tokenEndpoint, tenantID), "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
if err != nil { if err != nil {
return token, err return azureToken{}, err
} }
if res.StatusCode != 200 { if res.StatusCode != 200 {
return token, errors.Errorf("error while renewing access token, status : %s", res.Status) return azureToken{}, errors.Errorf("error while renewing access token, status : %s", res.Status)
} }
bits, err := ioutil.ReadAll(res.Body) bits, err := ioutil.ReadAll(res.Body)
if err != nil { if err != nil {
return token, err return azureToken{}, err
} }
token := azureToken{}
if err := json.Unmarshal(bits, &token); err != nil { if err := json.Unmarshal(bits, &token); err != nil {
return token, err return azureToken{}, err
} }
return token, nil return token, nil
} }
@ -259,7 +262,11 @@ func toOAuthToken(token azureToken) oauth2.Token {
} }
// NewAuthorizerFromLogin creates an authorizer based on login access token // NewAuthorizerFromLogin creates an authorizer based on login access token
func (login AzureLoginService) NewAuthorizerFromLogin() (autorest.Authorizer, error) { func NewAuthorizerFromLogin() (autorest.Authorizer, error) {
login, err := NewAzureLoginService()
if err != nil {
return nil, err
}
oauthToken, err := login.GetValidToken() oauthToken, err := login.GetValidToken()
if err != nil { if err != nil {
return nil, err return nil, err
@ -278,28 +285,28 @@ func (login AzureLoginService) NewAuthorizerFromLogin() (autorest.Authorizer, er
} }
// GetValidToken returns an access token. Refresh token if needed // GetValidToken returns an access token. Refresh token if needed
func (login AzureLoginService) GetValidToken() (token oauth2.Token, err error) { func (login AzureLoginService) GetValidToken() (oauth2.Token, error) {
loginInfo, err := login.tokenStore.readToken() loginInfo, err := login.tokenStore.readToken()
if err != nil { if err != nil {
return token, err return oauth2.Token{}, err
} }
token = loginInfo.Token token := loginInfo.Token
if token.Valid() { if token.Valid() {
return token, nil return token, nil
} }
tenantID := loginInfo.TenantID tenantID := loginInfo.TenantID
token, err = login.refreshToken(token.RefreshToken, tenantID) token, err = login.refreshToken(token.RefreshToken, tenantID)
if err != nil { if err != nil {
return token, errors.Wrap(err, "access token request failed. Maybe you need to login to azure again.") return oauth2.Token{}, errors.Wrap(err, "access token request failed. Maybe you need to login to azure again.")
} }
err = login.tokenStore.writeLoginInfo(TokenInfo{TenantID: tenantID, Token: token}) err = login.tokenStore.writeLoginInfo(TokenInfo{TenantID: tenantID, Token: token})
if err != nil { if err != nil {
return token, err return oauth2.Token{}, err
} }
return token, nil return token, nil
} }
func (login AzureLoginService) refreshToken(currentRefreshToken string, tenantID string) (oauthToken oauth2.Token, err error) { func (login AzureLoginService) refreshToken(currentRefreshToken string, tenantID string) (oauth2.Token, error) {
data := url.Values{ data := url.Values{
"grant_type": []string{"refresh_token"}, "grant_type": []string{"refresh_token"},
"client_id": []string{clientID}, "client_id": []string{clientID},
@ -308,7 +315,7 @@ func (login AzureLoginService) refreshToken(currentRefreshToken string, tenantID
} }
token, err := login.apiHelper.queryToken(data, tenantID) token, err := login.apiHelper.queryToken(data, tenantID)
if err != nil { if err != nil {
return oauthToken, err return oauth2.Token{}, err
} }
return toOAuthToken(token), nil return toOAuthToken(token), nil

View File

@ -8,8 +8,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@ -27,17 +25,18 @@ type LoginSuiteTest struct {
func (suite *LoginSuiteTest) BeforeTest(suiteName, testName string) { func (suite *LoginSuiteTest) BeforeTest(suiteName, testName string) {
dir, err := ioutil.TempDir("", "test_store") dir, err := ioutil.TempDir("", "test_store")
require.Nil(suite.T(), err) Expect(err).To(BeNil())
suite.dir = dir suite.dir = dir
suite.mockHelper = MockAzureHelper{} suite.mockHelper = MockAzureHelper{}
//nolint copylocks //nolint copylocks
suite.azureLogin = newAzureLoginServiceFromPath(filepath.Join(dir, tokenFilename), suite.mockHelper) suite.azureLogin, err = newAzureLoginServiceFromPath(filepath.Join(dir, tokenStoreFilename), suite.mockHelper)
Expect(err).To(BeNil())
} }
func (suite *LoginSuiteTest) AfterTest(suiteName, testName string) { func (suite *LoginSuiteTest) AfterTest(suiteName, testName string) {
err := os.RemoveAll(suite.dir) err := os.RemoveAll(suite.dir)
require.Nil(suite.T(), err) Expect(err).To(BeNil())
} }
func (suite *LoginSuiteTest) TestRefreshInValidToken() { func (suite *LoginSuiteTest) TestRefreshInValidToken() {
@ -55,8 +54,10 @@ func (suite *LoginSuiteTest) TestRefreshInValidToken() {
}, nil) }, nil)
//nolint copylocks //nolint copylocks
suite.azureLogin = newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenFilename), suite.mockHelper) azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper)
err := suite.azureLogin.tokenStore.writeLoginInfo(TokenInfo{ Expect(err).To(BeNil())
suite.azureLogin = azureLogin
err = suite.azureLogin.tokenStore.writeLoginInfo(TokenInfo{
TenantID: "123456", TenantID: "123456",
Token: oauth2.Token{ Token: oauth2.Token{
AccessToken: "accessToken", AccessToken: "accessToken",

View File

@ -2,7 +2,10 @@ package login
import ( import (
"encoding/json" "encoding/json"
"errors"
"io/ioutil" "io/ioutil"
"os"
"path/filepath"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
@ -17,6 +20,27 @@ type TokenInfo struct {
TenantID string `json:"tenantId"` TenantID string `json:"tenantId"`
} }
func newTokenStore(path string) (tokenStore, error) {
parentFolder := filepath.Dir(path)
dir, err := os.Stat(parentFolder)
if os.IsNotExist(err) {
err = os.MkdirAll(parentFolder, 0700)
if err != nil {
return tokenStore{}, err
}
dir, err = os.Stat(parentFolder)
}
if err != nil {
return tokenStore{}, err
}
if !dir.Mode().IsDir() {
return tokenStore{}, errors.New("cannot use path " + path + " ; " + parentFolder + " already exists and is not a directory")
}
return tokenStore{
filePath: path,
}, nil
}
func (store tokenStore) writeLoginInfo(info TokenInfo) error { func (store tokenStore) writeLoginInfo(info TokenInfo) error {
bytes, err := json.MarshalIndent(info, "", " ") bytes, err := json.MarshalIndent(info, "", " ")
if err != nil { if err != nil {
@ -25,13 +49,14 @@ func (store tokenStore) writeLoginInfo(info TokenInfo) error {
return ioutil.WriteFile(store.filePath, bytes, 0644) return ioutil.WriteFile(store.filePath, bytes, 0644)
} }
func (store tokenStore) readToken() (loginInfo TokenInfo, err error) { func (store tokenStore) readToken() (TokenInfo, error) {
bytes, err := ioutil.ReadFile(store.filePath) bytes, err := ioutil.ReadFile(store.filePath)
if err != nil { if err != nil {
return loginInfo, err return TokenInfo{}, err
} }
loginInfo := TokenInfo{}
if err := json.Unmarshal(bytes, &loginInfo); err != nil { if err := json.Unmarshal(bytes, &loginInfo); err != nil {
return loginInfo, err return TokenInfo{}, err
} }
return loginInfo, nil return loginInfo, nil
} }

View File

@ -0,0 +1,54 @@
package login
import (
"errors"
"io/ioutil"
"os"
"path/filepath"
"testing"
. "github.com/onsi/gomega"
"github.com/stretchr/testify/suite"
)
type tokenStoreTestSuite struct {
suite.Suite
}
func (suite *tokenStoreTestSuite) TestCreateStoreFromExistingFolder() {
existingDir, err := ioutil.TempDir("", "test_store")
Expect(err).To(BeNil())
storePath := filepath.Join(existingDir, tokenStoreFilename)
store, err := newTokenStore(storePath)
Expect(err).To(BeNil())
Expect((store.filePath)).To(Equal(storePath))
}
func (suite *tokenStoreTestSuite) TestCreateStoreFromNonExistingFolder() {
existingDir, err := ioutil.TempDir("", "test_store")
Expect(err).To(BeNil())
storePath := filepath.Join(existingDir, "new", tokenStoreFilename)
store, err := newTokenStore(storePath)
Expect(err).To(BeNil())
Expect((store.filePath)).To(Equal(storePath))
newDir, err := os.Stat(filepath.Join(existingDir, "new"))
Expect(err).To(BeNil())
Expect(newDir.Mode().IsDir()).To(BeTrue())
}
func (suite *tokenStoreTestSuite) TestErrorIfParentFolderIsAFile() {
existingDir, err := ioutil.TempFile("", "test_store")
Expect(err).To(BeNil())
storePath := filepath.Join(existingDir.Name(), tokenStoreFilename)
_, err = newTokenStore(storePath)
Expect(err).To(MatchError(errors.New("cannot use path " + storePath + " ; " + existingDir.Name() + " already exists and is not a directory")))
}
func TestTokenStoreSuite(t *testing.T) {
RegisterTestingT(t)
suite.Run(t, new(tokenStoreTestSuite))
}