Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update based on azcore refactor #15383

Merged
merged 1 commit into from
Aug 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
## v0.2.0 (2020-09-11)
### Features Added
* Refactor `azidentity` on top of `azcore` refactor
* Updated policies to conform to `azcore.Policy` interface changes.
* Updated policies to conform to `policy.Policy` interface changes.
* Updated non-retriable errors to conform to `azcore.NonRetriableError`.
* Fixed calls to `Request.SetBody()` to include content type.
* Switched endpoints to string types and removed extra parsing code.
Expand Down
76 changes: 37 additions & 39 deletions sdk/azidentity/aad_identity_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import (
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
)

const (
Expand Down Expand Up @@ -45,10 +48,10 @@ type interactiveConfig struct {
}

// aadIdentityClient provides the base for authenticating with Client Secret Credentials, Client Certificate Credentials
// and Environment Credentials. This type includes an azcore.Pipeline and TokenCredentialOptions.
// and Environment Credentials. This type includes an runtime.Pipeline and TokenCredentialOptions.
type aadIdentityClient struct {
authorityHost string
pipeline azcore.Pipeline
pipeline runtime.Pipeline
}

// newAADIdentityClient creates a new instance of the aadIdentityClient with the TokenCredentialOptions
Expand All @@ -57,11 +60,6 @@ type aadIdentityClient struct {
// will be used to retrieve tokens and authenticate
func newAADIdentityClient(authorityHost string, options pipelineOptions) (*aadIdentityClient, error) {
logEnvVars()
if options.Telemetry.Value == "" {
options.Telemetry.Value = UserAgent
} else {
options.Telemetry.Value += " " + UserAgent
}
return &aadIdentityClient{authorityHost: authorityHost, pipeline: newDefaultPipeline(options)}, nil
}

Expand All @@ -83,7 +81,7 @@ func (c *aadIdentityClient) refreshAccessToken(ctx context.Context, tenantID str
return nil, err
}

if azcore.HasStatusCode(resp, successStatusCodes[:]...) {
if runtime.HasStatusCode(resp, successStatusCodes[:]...) {
return c.createRefreshAccessToken(resp)
}

Expand All @@ -108,7 +106,7 @@ func (c *aadIdentityClient) authenticate(ctx context.Context, tenantID string, c
return nil, err
}

if azcore.HasStatusCode(resp, successStatusCodes[:]...) {
if runtime.HasStatusCode(resp, successStatusCodes[:]...) {
return c.createAccessToken(resp)
}

Expand All @@ -133,7 +131,7 @@ func (c *aadIdentityClient) authenticateCertificate(ctx context.Context, tenantI
return nil, err
}

if azcore.HasStatusCode(resp, successStatusCodes[:]...) {
if runtime.HasStatusCode(resp, successStatusCodes[:]...) {
return c.createAccessToken(resp)
}

Expand All @@ -146,7 +144,7 @@ func (c *aadIdentityClient) createAccessToken(res *http.Response) (*azcore.Acces
ExpiresIn json.Number `json:"expires_in"`
ExpiresOn string `json:"expires_on"`
}{}
if err := azcore.UnmarshalAsJSON(res, &value); err != nil {
if err := runtime.UnmarshalAsJSON(res, &value); err != nil {
return nil, fmt.Errorf("internal AccessToken: %w", err)
}
t, err := value.ExpiresIn.Int64()
Expand All @@ -168,7 +166,7 @@ func (c *aadIdentityClient) createRefreshAccessToken(res *http.Response) (*token
ExpiresIn json.Number `json:"expires_in"`
ExpiresOn string `json:"expires_on"`
}{}
if err := azcore.UnmarshalAsJSON(res, &value); err != nil {
if err := runtime.UnmarshalAsJSON(res, &value); err != nil {
return nil, fmt.Errorf("internal AccessToken: %w", err)
}
t, err := value.ExpiresIn.Int64()
Expand All @@ -182,7 +180,7 @@ func (c *aadIdentityClient) createRefreshAccessToken(res *http.Response) (*token
return &tokenResponse{token: accessToken, refreshToken: value.RefreshToken}, nil
}

func (c *aadIdentityClient) createRefreshTokenRequest(ctx context.Context, tenantID, clientID, clientSecret, refreshToken string, scopes []string) (*azcore.Request, error) {
func (c *aadIdentityClient) createRefreshTokenRequest(ctx context.Context, tenantID, clientID, clientSecret, refreshToken string, scopes []string) (*policy.Request, error) {
data := url.Values{}
data.Set(qpGrantType, "refresh_token")
data.Set(qpClientID, clientID)
Expand All @@ -193,8 +191,8 @@ func (c *aadIdentityClient) createRefreshTokenRequest(ctx context.Context, tenan
data.Set(qpRefreshToken, refreshToken)
data.Set(qpScope, strings.Join(scopes, " "))
dataEncoded := data.Encode()
body := azcore.NopCloser(strings.NewReader(dataEncoded))
req, err := azcore.NewRequest(ctx, http.MethodPost, azcore.JoinPaths(c.authorityHost, tenantID, tokenEndpoint(oauthPath(tenantID))))
body := streaming.NopCloser(strings.NewReader(dataEncoded))
req, err := runtime.NewRequest(ctx, http.MethodPost, runtime.JoinPaths(c.authorityHost, tenantID, tokenEndpoint(oauthPath(tenantID))))
if err != nil {
return nil, err
}
Expand All @@ -204,15 +202,15 @@ func (c *aadIdentityClient) createRefreshTokenRequest(ctx context.Context, tenan
return req, nil
}

func (c *aadIdentityClient) createClientSecretAuthRequest(ctx context.Context, tenantID string, clientID string, clientSecret string, scopes []string) (*azcore.Request, error) {
func (c *aadIdentityClient) createClientSecretAuthRequest(ctx context.Context, tenantID string, clientID string, clientSecret string, scopes []string) (*policy.Request, error) {
data := url.Values{}
data.Set(qpGrantType, "client_credentials")
data.Set(qpClientID, clientID)
data.Set(qpClientSecret, clientSecret)
data.Set(qpScope, strings.Join(scopes, " "))
dataEncoded := data.Encode()
body := azcore.NopCloser(strings.NewReader(dataEncoded))
req, err := azcore.NewRequest(ctx, http.MethodPost, azcore.JoinPaths(c.authorityHost, tenantID, tokenEndpoint(oauthPath(tenantID))))
body := streaming.NopCloser(strings.NewReader(dataEncoded))
req, err := runtime.NewRequest(ctx, http.MethodPost, runtime.JoinPaths(c.authorityHost, tenantID, tokenEndpoint(oauthPath(tenantID))))
if err != nil {
return nil, err
}
Expand All @@ -223,8 +221,8 @@ func (c *aadIdentityClient) createClientSecretAuthRequest(ctx context.Context, t
return req, nil
}

func (c *aadIdentityClient) createClientCertificateAuthRequest(ctx context.Context, tenantID string, clientID string, cert *certContents, sendCertificateChain bool, scopes []string) (*azcore.Request, error) {
u := azcore.JoinPaths(c.authorityHost, tenantID, tokenEndpoint(oauthPath(tenantID)))
func (c *aadIdentityClient) createClientCertificateAuthRequest(ctx context.Context, tenantID string, clientID string, cert *certContents, sendCertificateChain bool, scopes []string) (*policy.Request, error) {
u := runtime.JoinPaths(c.authorityHost, tenantID, tokenEndpoint(oauthPath(tenantID)))
clientAssertion, err := createClientAssertionJWT(clientID, u, cert, sendCertificateChain)
if err != nil {
return nil, err
Expand All @@ -237,8 +235,8 @@ func (c *aadIdentityClient) createClientCertificateAuthRequest(ctx context.Conte
data.Set(qpClientAssertion, clientAssertion)
data.Set(qpScope, strings.Join(scopes, " "))
dataEncoded := data.Encode()
body := azcore.NopCloser(strings.NewReader(dataEncoded))
req, err := azcore.NewRequest(ctx, http.MethodPost, u)
body := streaming.NopCloser(strings.NewReader(dataEncoded))
req, err := runtime.NewRequest(ctx, http.MethodPost, u)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -267,14 +265,14 @@ func (c *aadIdentityClient) authenticateUsernamePassword(ctx context.Context, te
return nil, err
}

if azcore.HasStatusCode(resp, successStatusCodes[:]...) {
if runtime.HasStatusCode(resp, successStatusCodes[:]...) {
return c.createAccessToken(resp)
}

return nil, &AuthenticationFailedError{inner: newAADAuthenticationFailedError(resp)}
}

func (c *aadIdentityClient) createUsernamePasswordAuthRequest(ctx context.Context, tenantID string, clientID string, username string, password string, scopes []string) (*azcore.Request, error) {
func (c *aadIdentityClient) createUsernamePasswordAuthRequest(ctx context.Context, tenantID string, clientID string, username string, password string, scopes []string) (*policy.Request, error) {
data := url.Values{}
data.Set(qpResponseType, "token")
data.Set(qpGrantType, "password")
Expand All @@ -283,8 +281,8 @@ func (c *aadIdentityClient) createUsernamePasswordAuthRequest(ctx context.Contex
data.Set(qpPassword, password)
data.Set(qpScope, strings.Join(scopes, " "))
dataEncoded := data.Encode()
body := azcore.NopCloser(strings.NewReader(dataEncoded))
req, err := azcore.NewRequest(ctx, http.MethodPost, azcore.JoinPaths(c.authorityHost, tenantID, tokenEndpoint(oauthPath(tenantID))))
body := streaming.NopCloser(strings.NewReader(dataEncoded))
req, err := runtime.NewRequest(ctx, http.MethodPost, runtime.JoinPaths(c.authorityHost, tenantID, tokenEndpoint(oauthPath(tenantID))))
if err != nil {
return nil, err
}
Expand All @@ -296,7 +294,7 @@ func (c *aadIdentityClient) createUsernamePasswordAuthRequest(ctx context.Contex

func createDeviceCodeResult(res *http.Response) (*deviceCodeResult, error) {
value := &deviceCodeResult{}
if err := azcore.UnmarshalAsJSON(res, &value); err != nil {
if err := runtime.UnmarshalAsJSON(res, &value); err != nil {
return nil, fmt.Errorf("DeviceCodeResult: %w", err)
}
return value, nil
Expand All @@ -320,22 +318,22 @@ func (c *aadIdentityClient) authenticateDeviceCode(ctx context.Context, tenantID
return nil, err
}

if azcore.HasStatusCode(resp, successStatusCodes[:]...) {
if runtime.HasStatusCode(resp, successStatusCodes[:]...) {
return c.createRefreshAccessToken(resp)
}

return nil, &AuthenticationFailedError{inner: newAADAuthenticationFailedError(resp)}
}

func (c *aadIdentityClient) createDeviceCodeAuthRequest(ctx context.Context, tenantID string, clientID string, deviceCode string, scopes []string) (*azcore.Request, error) {
func (c *aadIdentityClient) createDeviceCodeAuthRequest(ctx context.Context, tenantID string, clientID string, deviceCode string, scopes []string) (*policy.Request, error) {
data := url.Values{}
data.Set(qpGrantType, deviceCodeGrantType)
data.Set(qpClientID, clientID)
data.Set(qpDeviceCode, deviceCode)
data.Set(qpScope, strings.Join(scopes, " "))
dataEncoded := data.Encode()
body := azcore.NopCloser(strings.NewReader(dataEncoded))
req, err := azcore.NewRequest(ctx, http.MethodPost, azcore.JoinPaths(c.authorityHost, tenantID, tokenEndpoint(oauthPath(tenantID))))
body := streaming.NopCloser(strings.NewReader(dataEncoded))
req, err := runtime.NewRequest(ctx, http.MethodPost, runtime.JoinPaths(c.authorityHost, tenantID, tokenEndpoint(oauthPath(tenantID))))
if err != nil {
return nil, err
}
Expand All @@ -356,20 +354,20 @@ func (c *aadIdentityClient) requestNewDeviceCode(ctx context.Context, tenantID,
return nil, err
}

if azcore.HasStatusCode(resp, successStatusCodes[:]...) {
if runtime.HasStatusCode(resp, successStatusCodes[:]...) {
return createDeviceCodeResult(resp)
}
return nil, &AuthenticationFailedError{inner: newAADAuthenticationFailedError(resp)}
}

func (c *aadIdentityClient) createDeviceCodeNumberRequest(ctx context.Context, tenantID string, clientID string, scopes []string) (*azcore.Request, error) {
func (c *aadIdentityClient) createDeviceCodeNumberRequest(ctx context.Context, tenantID string, clientID string, scopes []string) (*policy.Request, error) {
data := url.Values{}
data.Set(qpClientID, clientID)
data.Set(qpScope, strings.Join(scopes, " "))
dataEncoded := data.Encode()
body := azcore.NopCloser(strings.NewReader(dataEncoded))
body := streaming.NopCloser(strings.NewReader(dataEncoded))
// endpoint that will return a device code along with the other necessary authentication flow parameters in the DeviceCodeResult struct
req, err := azcore.NewRequest(ctx, http.MethodPost, azcore.JoinPaths(c.authorityHost, tenantID, path.Join(oauthPath(tenantID), "/devicecode")))
req, err := runtime.NewRequest(ctx, http.MethodPost, runtime.JoinPaths(c.authorityHost, tenantID, path.Join(oauthPath(tenantID), "/devicecode")))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -408,15 +406,15 @@ func (c *aadIdentityClient) authenticateAuthCode(ctx context.Context, tenantID,
return nil, err
}

if azcore.HasStatusCode(resp, successStatusCodes[:]...) {
if runtime.HasStatusCode(resp, successStatusCodes[:]...) {
return c.createAccessToken(resp)
}

return nil, &AuthenticationFailedError{inner: newAADAuthenticationFailedError(resp)}
}

// createAuthorizationCodeAuthRequest creates a request for an Access Token for authorization_code grant types.
func (c *aadIdentityClient) createAuthorizationCodeAuthRequest(ctx context.Context, tenantID, clientID, authCode, clientSecret, codeVerifier, redirectURI string, scopes []string) (*azcore.Request, error) {
func (c *aadIdentityClient) createAuthorizationCodeAuthRequest(ctx context.Context, tenantID, clientID, authCode, clientSecret, codeVerifier, redirectURI string, scopes []string) (*policy.Request, error) {
data := url.Values{}
data.Set(qpGrantType, "authorization_code")
data.Set(qpClientID, clientID)
Expand All @@ -431,8 +429,8 @@ func (c *aadIdentityClient) createAuthorizationCodeAuthRequest(ctx context.Conte
data.Set(qpScope, strings.Join(scopes, " "))
data.Set(qpCode, authCode)
dataEncoded := data.Encode()
body := azcore.NopCloser(strings.NewReader(dataEncoded))
req, err := azcore.NewRequest(ctx, http.MethodPost, azcore.JoinPaths(c.authorityHost, tenantID, tokenEndpoint(oauthPath(tenantID))))
body := streaming.NopCloser(strings.NewReader(dataEncoded))
req, err := runtime.NewRequest(ctx, http.MethodPost, runtime.JoinPaths(c.authorityHost, tenantID, tokenEndpoint(oauthPath(tenantID))))
if err != nil {
return nil, err
}
Expand Down
12 changes: 6 additions & 6 deletions sdk/azidentity/aad_identity_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"strings"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
)

Expand Down Expand Up @@ -52,7 +52,7 @@ func TestTelemetryDefaultUserAgent(t *testing.T) {
if err != nil {
t.Fatalf("Unable to create credential. Received: %v", err)
}
req, err := azcore.NewRequest(context.Background(), http.MethodGet, srv.URL())
req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
Expand All @@ -63,7 +63,7 @@ func TestTelemetryDefaultUserAgent(t *testing.T) {
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
if ua := resp.Request.Header.Get(headerUserAgent); !strings.HasPrefix(ua, UserAgent) {
if ua := resp.Request.Header.Get(headerUserAgent); !strings.HasPrefix(ua, "azsdk-go-"+component+"/"+version) {
t.Fatalf("unexpected User-Agent %s", ua)
}
}
Expand All @@ -76,12 +76,12 @@ func TestTelemetryCustom(t *testing.T) {
options := pipelineOptions{
HTTPClient: srv,
}
options.Telemetry.Value = customTelemetry
options.Telemetry.ApplicationID = customTelemetry
client, err := newAADIdentityClient(srv.URL(), options)
if err != nil {
t.Fatalf("Unable to create credential. Received: %v", err)
}
req, err := azcore.NewRequest(context.Background(), http.MethodGet, srv.URL())
req, err := runtime.NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
Expand All @@ -92,7 +92,7 @@ func TestTelemetryCustom(t *testing.T) {
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
if ua := resp.Request.Header.Get(headerUserAgent); !strings.HasPrefix(ua, customTelemetry+" "+UserAgent) {
if ua := resp.Request.Header.Get(headerUserAgent); !strings.HasPrefix(ua, customTelemetry+" "+"azsdk-go-"+component+"/"+version) {
t.Fatalf("unexpected User-Agent %s", ua)
}
}
14 changes: 8 additions & 6 deletions sdk/azidentity/authorization_code_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"context"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
)

// AuthorizationCodeCredentialOptions contain optional parameters that can be used to configure the AuthorizationCodeCredential.
Expand All @@ -19,13 +21,13 @@ type AuthorizationCodeCredentialOptions struct {
AuthorityHost string
// HTTPClient sets the transport for making HTTP requests
// Leave this as nil to use the default HTTP transport
HTTPClient azcore.Transporter
HTTPClient policy.Transporter
// Retry configures the built-in retry policy behavior
Retry azcore.RetryOptions
Retry policy.RetryOptions
// Telemetry configures the built-in telemetry policy behavior
Telemetry azcore.TelemetryOptions
Telemetry policy.TelemetryOptions
// Logging configures the built-in logging policy behavior.
Logging azcore.LogOptions
Logging policy.LogOptions
}

// AuthorizationCodeCredential enables authentication to Azure Active Directory using an authorization code
Expand Down Expand Up @@ -68,7 +70,7 @@ func NewAuthorizationCodeCredential(tenantID string, clientID string, authCode s
// ctx: Context used to control the request lifetime.
// opts: TokenRequestOptions contains the list of scopes for which the token will have access.
// Returns an AccessToken which can be used to authenticate service client calls.
func (c *AuthorizationCodeCredential) GetToken(ctx context.Context, opts azcore.TokenRequestOptions) (*azcore.AccessToken, error) {
func (c *AuthorizationCodeCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (*azcore.AccessToken, error) {
tk, err := c.client.authenticateAuthCode(ctx, c.tenantID, c.clientID, c.authCode, c.clientSecret, "", c.redirectURI, opts.Scopes)
if err != nil {
addGetTokenFailureLogs("Authorization Code Credential", err, true)
Expand All @@ -79,7 +81,7 @@ func (c *AuthorizationCodeCredential) GetToken(ctx context.Context, opts azcore.
}

// NewAuthenticationPolicy implements the azcore.Credential interface on AuthorizationCodeCredential.
func (c *AuthorizationCodeCredential) NewAuthenticationPolicy(options azcore.AuthenticationOptions) azcore.Policy {
func (c *AuthorizationCodeCredential) NewAuthenticationPolicy(options runtime.AuthenticationOptions) policy.Policy {
return newBearerTokenPolicy(c, options)
}

Expand Down
Loading