diff --git a/aci/backend.go b/aci/backend.go index 1a28e9550..009ddd20a 100644 --- a/aci/backend.go +++ b/aci/backend.go @@ -64,7 +64,19 @@ type ContextParams struct { // LoginParams azure login options type LoginParams struct { - TenantID string + TenantID string + ClientID string + ClientSecret string +} + +// Validate returns an error if options are not used properly +func (opts LoginParams) Validate() error { + if opts.ClientID != "" || opts.ClientSecret != "" { + if opts.ClientID == "" || opts.ClientSecret == "" || opts.TenantID == "" { + return errors.New("for Service Principal login, 3 options must be specified: --client-id, --client-secret and --tenant-id") + } + } + return nil } func init() { @@ -377,12 +389,18 @@ func (cs *aciComposeService) Logs(ctx context.Context, opts cli.ProjectOptions) } type aciCloudService struct { - loginService *login.AzureLoginService + loginService login.AzureLoginServiceAPI } func (cs *aciCloudService) Login(ctx context.Context, params interface{}) error { - createOpts := params.(LoginParams) - return cs.loginService.Login(ctx, createOpts.TenantID) + opts, ok := params.(LoginParams) + if !ok { + return errors.New("Could not read azure LoginParams struct from generic parameter") + } + if opts.ClientID != "" { + return cs.loginService.LoginServicePrincipal(opts.ClientID, opts.ClientSecret, opts.TenantID) + } + return cs.loginService.Login(ctx, opts.TenantID) } func (cs *aciCloudService) Logout(ctx context.Context) error { diff --git a/aci/backend_test.go b/aci/backend_test.go index e6b72765c..14c8b92e8 100644 --- a/aci/backend_test.go +++ b/aci/backend_test.go @@ -20,9 +20,10 @@ import ( "context" "testing" - "github.com/docker/api/containers" - + "github.com/stretchr/testify/mock" "gotest.tools/v3/assert" + + "github.com/docker/api/containers" ) func TestGetContainerName(t *testing.T) { @@ -58,3 +59,86 @@ func TestVerifyCommand(t *testing.T) { assert.Error(t, err, "ACI exec command does not accept arguments to the command. "+ "Only the binary should be specified") } + +func TestLoginParamsValidate(t *testing.T) { + err := LoginParams{ + ClientID: "someID", + }.Validate() + assert.Error(t, err, "for Service Principal login, 3 options must be specified: --client-id, --client-secret and --tenant-id") + + err = LoginParams{ + ClientSecret: "someSecret", + }.Validate() + assert.Error(t, err, "for Service Principal login, 3 options must be specified: --client-id, --client-secret and --tenant-id") + + err = LoginParams{}.Validate() + assert.NilError(t, err) + + err = LoginParams{ + TenantID: "tenant", + }.Validate() + assert.NilError(t, err) +} + +func TestLoginServicePrincipal(t *testing.T) { + loginService := mockLoginService{} + loginService.On("LoginServicePrincipal", "someID", "secret", "tenant").Return(nil) + loginBackend := aciCloudService{ + loginService: &loginService, + } + + err := loginBackend.Login(context.Background(), LoginParams{ + ClientID: "someID", + ClientSecret: "secret", + TenantID: "tenant", + }) + + assert.NilError(t, err) +} + +func TestLoginWithTenant(t *testing.T) { + loginService := mockLoginService{} + ctx := context.Background() + loginService.On("Login", ctx, "tenant").Return(nil) + loginBackend := aciCloudService{ + loginService: &loginService, + } + + err := loginBackend.Login(ctx, LoginParams{ + TenantID: "tenant", + }) + + assert.NilError(t, err) +} + +func TestLoginWithoutTenant(t *testing.T) { + loginService := mockLoginService{} + ctx := context.Background() + loginService.On("Login", ctx, "").Return(nil) + loginBackend := aciCloudService{ + loginService: &loginService, + } + + err := loginBackend.Login(ctx, LoginParams{}) + + assert.NilError(t, err) +} + +type mockLoginService struct { + mock.Mock +} + +func (s *mockLoginService) Login(ctx context.Context, requestedTenantID string) error { + args := s.Called(ctx, requestedTenantID) + return args.Error(0) +} + +func (s *mockLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error { + args := s.Called(clientID, clientSecret, tenantID) + return args.Error(0) +} + +func (s *mockLoginService) Logout(ctx context.Context) error { + args := s.Called(ctx) + return args.Error(0) +} diff --git a/aci/login/login.go b/aci/login/login.go index d3533a4f2..f7bacdc02 100644 --- a/aci/login/login.go +++ b/aci/login/login.go @@ -72,6 +72,13 @@ type AzureLoginService struct { apiHelper apiHelper } +// AzureLoginServiceAPI interface for Azure login service +type AzureLoginServiceAPI interface { + LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error + Login(ctx context.Context, requestedTenantID string) error + Logout(ctx context.Context) error +} + const tokenStoreFilename = "dockerAccessToken.json" // NewAzureLoginService creates a NewAzureLoginService @@ -90,9 +97,9 @@ func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper) (*Azu }, nil } -// TestLoginFromServicePrincipal login with clientId / clientSecret from a previously created service principal. -// The resulting token does not include a refresh token, used for tests only -func (login *AzureLoginService) TestLoginFromServicePrincipal(clientID string, clientSecret string, tenantID string) error { +// LoginServicePrincipal login with clientId / clientSecret from a service principal. +// The resulting token does not include a refresh token +func (login *AzureLoginService) LoginServicePrincipal(clientID string, clientSecret string, tenantID string) error { // Tried with auth2.NewUsernamePasswordConfig() but could not make this work with username / password, setting this for CI with clientID / clientSecret creds := auth2.NewClientCredentialsConfig(clientID, clientSecret, tenantID) diff --git a/cli/cmd/login/azurelogin.go b/cli/cmd/login/azurelogin.go index 4d80413c2..4b7f3a28f 100644 --- a/cli/cmd/login/azurelogin.go +++ b/cli/cmd/login/azurelogin.go @@ -14,11 +14,16 @@ func AzureLoginCommand() *cobra.Command { Short: "Log in to azure", Args: cobra.MaximumNArgs(0), RunE: func(cmd *cobra.Command, args []string) error { + if err := opts.Validate(); err != nil { + return err + } return cloudLogin(cmd, "aci", opts) }, } flags := cmd.Flags() - flags.StringVar(&opts.TenantID, "tenant-id", "", "Specify tenant ID to use from your azure account") + flags.StringVar(&opts.TenantID, "tenant-id", "", "Specify tenant ID to use") + flags.StringVar(&opts.ClientID, "client-id", "", "Client ID for Service principal login") + flags.StringVar(&opts.ClientSecret, "client-secret", "", "Client secret for Service principal login") return cmd } diff --git a/tests/aci-e2e/e2e-aci_test.go b/tests/aci-e2e/e2e-aci_test.go index 19831228a..c85e6ed77 100644 --- a/tests/aci-e2e/e2e-aci_test.go +++ b/tests/aci-e2e/e2e-aci_test.go @@ -76,7 +76,7 @@ func TestLoginLogout(t *testing.T) { rg := "E2E-" + startTime t.Run("login", func(t *testing.T) { - azureLogin(t) + azureLogin(t, c) }) t.Run("create context", func(t *testing.T) { @@ -506,7 +506,7 @@ func TestRunEnvVars(t *testing.T) { func setupTestResourceGroup(t *testing.T, c *E2eCLI, tName string) (string, string) { startTime := strconv.Itoa(int(time.Now().Unix())) rg := "E2E-" + tName + "-" + startTime - azureLogin(t) + azureLogin(t, c) sID := getSubscriptionID(t) t.Logf("Create resource group %q", rg) err := createResourceGroup(sID, rg) @@ -537,17 +537,14 @@ func deleteResourceGroup(rgName string) error { return helper.DeleteAsync(ctx, *models[0].SubscriptionID, rgName) } -func azureLogin(t *testing.T) { +func azureLogin(t *testing.T, c *E2eCLI) { t.Log("Log in to Azure") - login, err := login.NewAzureLoginService() - assert.NilError(t, err) - // in order to create new service principal and get these 3 values : `az ad sp create-for-rbac --name 'TestServicePrincipal' --sdk-auth` clientID := os.Getenv("AZURE_CLIENT_ID") clientSecret := os.Getenv("AZURE_CLIENT_SECRET") tenantID := os.Getenv("AZURE_TENANT_ID") - err = login.TestLoginFromServicePrincipal(clientID, clientSecret, tenantID) - assert.NilError(t, err) + res := c.RunDockerCmd("login", "azure", "--client-id", clientID, "--client-secret", clientSecret, "--tenant-id", tenantID) + res.Assert(t, icmd.Success) } func getSubscriptionID(t *testing.T) string {