diff --git a/google/auth/_default.py b/google/auth/_default.py index a2a07800a..bef09659b 100644 --- a/google/auth/_default.py +++ b/google/auth/_default.py @@ -35,12 +35,14 @@ _AUTHORIZED_USER_TYPE = "authorized_user" _SERVICE_ACCOUNT_TYPE = "service_account" _EXTERNAL_ACCOUNT_TYPE = "external_account" +_EXTERNAL_ACCOUNT_AUTHORIZED_USER_TYPE = "external_account_authorized_user" _IMPERSONATED_SERVICE_ACCOUNT_TYPE = "impersonated_service_account" _GDCH_SERVICE_ACCOUNT_TYPE = "gdch_service_account" _VALID_TYPES = ( _AUTHORIZED_USER_TYPE, _SERVICE_ACCOUNT_TYPE, _EXTERNAL_ACCOUNT_TYPE, + _EXTERNAL_ACCOUNT_AUTHORIZED_USER_TYPE, _IMPERSONATED_SERVICE_ACCOUNT_TYPE, _GDCH_SERVICE_ACCOUNT_TYPE, ) @@ -158,6 +160,12 @@ def _load_credentials_from_info( default_scopes=default_scopes, request=request, ) + + elif credential_type == _EXTERNAL_ACCOUNT_AUTHORIZED_USER_TYPE: + credentials, project_id = _get_external_account_authorized_user_credentials( + filename, info, request + ) + elif credential_type == _IMPERSONATED_SERVICE_ACCOUNT_TYPE: credentials, project_id = _get_impersonated_service_account_credentials( filename, info, scopes @@ -363,6 +371,23 @@ def _get_external_account_credentials( return credentials, credentials.get_project_id(request=request) +def _get_external_account_authorized_user_credentials( + filename, info, scopes=None, default_scopes=None, request=None +): + try: + from google.auth import external_account_authorized_user + + credentials = external_account_authorized_user.Credentials.from_info(info) + except ValueError: + raise exceptions.DefaultCredentialsError( + "Failed to load external account authorized user credentials from {}".format( + filename + ) + ) + + return credentials, None + + def _get_authorized_user_credentials(filename, info, scopes=None): from google.oauth2 import credentials diff --git a/google/auth/external_account_authorized_user.py b/google/auth/external_account_authorized_user.py new file mode 100644 index 000000000..c0ffc49f3 --- /dev/null +++ b/google/auth/external_account_authorized_user.py @@ -0,0 +1,290 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""External Account Authorized User Credentials. +This module provides credentials based on OAuth 2.0 access and refresh tokens. +These credentials usually access resources on behalf of a user (resource +owner). + +Specifically, these are sourced using external identities via Workforce Identity Federation. + +Obtaining the initial access and refresh token can be done through the Google Cloud CLI. + +Example credential: +{ + "type": "external_account_authorized_user", + "audience": "//iam.googleapis.com/locations/global/workforcePools/$WORKFORCE_POOL_ID/providers/$PROVIDER_ID", + "refresh_token": "refreshToken", + "token_url": "https://sts.googleapis.com/v1/oauth/token", + "token_info_url": "https://sts.googleapis.com/v1/instrospect", + "client_id": "clientId", + "client_secret": "clientSecret" +} +""" + +import datetime +import io +import json + +from google.auth import _helpers +from google.auth import credentials +from google.auth import exceptions +from google.oauth2 import sts +from google.oauth2 import utils + +_EXTERNAL_ACCOUNT_AUTHORIZED_USER_JSON_TYPE = "external_account_authorized_user" + + +class Credentials( + credentials.CredentialsWithQuotaProject, + credentials.ReadOnlyScoped, + credentials.CredentialsWithTokenUri, +): + """Credentials for External Account Authorized Users. + + This is used to instantiate Credentials for exchanging refresh tokens from + authorized users for Google access token and authorizing requests to Google + APIs. + + The credentials are considered immutable. If you want to modify the + quota project, use `with_quota_project` and if you want to modify the token + uri, use `with_token_uri`. + """ + + def __init__( + self, + token=None, + expiry=None, + refresh_token=None, + audience=None, + client_id=None, + client_secret=None, + token_url=None, + token_info_url=None, + revoke_url=None, + quota_project_id=None, + ): + """Instantiates a external account authorized user credentials object. + + Args: + token (str): The OAuth 2.0 access token. Can be None if refresh information + is provided. + expiry (datetime.datetime): The optional expiration datetime of the OAuth 2.0 access + token. + refresh_token (str): The optional OAuth 2.0 refresh token. If specified, + credentials can be refreshed. + audience (str): The optional STS audience which contains the resource name for the workforce + pool and the provider identifier in that pool. + client_id (str): The OAuth 2.0 client ID. Must be specified for refresh, can be left as + None if the token can not be refreshed. + client_secret (str): The OAuth 2.0 client secret. Must be specified for refresh, can be + left as None if the token can not be refreshed. + token_url (str): The optional STS token exchange endpoint. Must be specified fro refresh, + can be leftas None if the token can not be refreshed. + token_info_url (str): The optional STS endpoint URL for token introspection. + revoke_url (str): The optional STS endpoint URL for revoking tokens. + quota_project_id (str): The optional project ID used for quota and billing. + This project may be different from the project used to + create the credentials. + + Returns: + google.auth.external_account_authorized_user.Credentials: The + constructed credentials. + """ + if not any((refresh_token, token)): + raise ValueError("Either `refresh_token` or `token` should be set.") + + super(Credentials, self).__init__() + + self.token = token + self.expiry = expiry + self._audience = audience + self._refresh_token = refresh_token + self._token_url = token_url + self._token_info_url = token_info_url + self._client_id = client_id + self._client_secret = client_secret + self._revoke_url = revoke_url + self._quota_project_id = quota_project_id + + self._client_auth = None + if self._client_id: + self._client_auth = utils.ClientAuthentication( + utils.ClientAuthType.basic, self._client_id, self._client_secret + ) + self._sts_client = sts.Client(self._token_url, self._client_auth) + + @property + def info(self): + """Generates the serializable dictionary representation of the current + credentials. + + Returns: + Mapping: The dictionary representation of the credentials. This is the + reverse of the "from_info" method defined in this class. It is + useful for serializing the current credentials so it can deserialized + later. + """ + config_info = self.constructor_args() + config_info.update(type=_EXTERNAL_ACCOUNT_AUTHORIZED_USER_JSON_TYPE) + if config_info["expiry"]: + config_info["expiry"] = config_info["expiry"].isoformat() + "Z" + + return {key: value for key, value in config_info.items() if value is not None} + + def constructor_args(self): + return { + "audience": self._audience, + "refresh_token": self._refresh_token, + "token_url": self._token_url, + "token_info_url": self._token_info_url, + "client_id": self._client_id, + "client_secret": self._client_secret, + "token": self.token, + "expiry": self.expiry, + "revoke_url": self._revoke_url, + "quota_project_id": self._quota_project_id, + } + + @property + def requires_scopes(self): + """ False: OAuth 2.0 credentials have their scopes set when + the initial token is requested and can not be changed.""" + return False + + @property + def is_user(self): + """ True: This credential always represents a user.""" + return True + + def get_project_id(self): + """Retrieves the project ID corresponding to the workload identity or workforce pool. + For workforce pool credentials, it returns the project ID corresponding to + the workforce_pool_user_project. + + When not determinable, None is returned. + """ + + return None + + def to_json(self, strip=None): + """Utility function that creates a JSON representation of this + credential. + Args: + strip (Sequence[str]): Optional list of members to exclude from the + generated JSON. + Returns: + str: A JSON representation of this instance. When converted into + a dictionary, it can be passed to from_info() + to create a new instance. + """ + strip = strip if strip else [] + return json.dumps({k: v for (k, v) in self.info.items() if k not in strip}) + + def refresh(self, request): + """Refreshes the access token. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + + Raises: + google.auth.exceptions.RefreshError: If the credentials could + not be refreshed. + """ + if not all( + (self._refresh_token, self._token_url, self._client_id, self._client_secret) + ): + raise exceptions.RefreshError( + "The credentials do not contain the necessary fields need to " + "refresh the access token. You must specify refresh_token, " + "token_url, client_id, and client_secret." + ) + + now = _helpers.utcnow() + response_data = self._make_sts_request(request) + + self.token = response_data.get("access_token") + + lifetime = datetime.timedelta(seconds=response_data.get("expires_in")) + self.expiry = now + lifetime + + if "refresh_token" in response_data: + self._refresh_token = response_data["refresh_token"] + + def _make_sts_request(self, request): + return self._sts_client.refresh_token(request, self._refresh_token) + + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) + def with_quota_project(self, quota_project_id): + kwargs = self.constructor_args() + kwargs.update(quota_project_id=quota_project_id) + return self.__class__(**kwargs) + + @_helpers.copy_docstring(credentials.CredentialsWithTokenUri) + def with_token_uri(self, token_uri): + kwargs = self.constructor_args() + kwargs.update(token_url=token_uri) + return self.__class__(**kwargs) + + @classmethod + def from_info(cls, info, **kwargs): + """Creates a Credentials instance from parsed external account info. + + Args: + info (Mapping[str, str]): The external account info in Google + format. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.external_account_authorized_user.Credentials: The + constructed credentials. + + Raises: + ValueError: For invalid parameters. + """ + expiry = info.get("expiry") + if expiry: + expiry = datetime.datetime.strptime( + expiry.rstrip("Z").split(".")[0], "%Y-%m-%dT%H:%M:%S" + ) + return cls( + audience=info.get("audience"), + refresh_token=info.get("refresh_token"), + token_url=info.get("token_url"), + token_info_url=info.get("token_info_url"), + client_id=info.get("client_id"), + client_secret=info.get("client_secret"), + token=info.get("token"), + expiry=expiry, + revoke_url=info.get("revoke_url"), + quota_project_id=info.get("quota_project_id"), + **kwargs + ) + + @classmethod + def from_file(cls, filename, **kwargs): + """Creates a Credentials instance from an external account json file. + + Args: + filename (str): The path to the external account json file. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.external_account_authorized_user.Credentials: The + constructed credentials. + """ + with io.open(filename, "r", encoding="utf-8") as json_file: + data = json.load(json_file) + return cls.from_info(data, **kwargs) diff --git a/google/oauth2/sts.py b/google/oauth2/sts.py index ae3c0146b..5cf06d4d4 100644 --- a/google/oauth2/sts.py +++ b/google/oauth2/sts.py @@ -58,6 +58,41 @@ def __init__(self, token_exchange_endpoint, client_authentication=None): super(Client, self).__init__(client_authentication) self._token_exchange_endpoint = token_exchange_endpoint + def _make_request(self, request, headers, request_body): + # Initialize request headers. + request_headers = _URLENCODED_HEADERS.copy() + + # Inject additional headers. + if headers: + for k, v in dict(headers).items(): + request_headers[k] = v + + # Apply OAuth client authentication. + self.apply_client_authentication_options(request_headers, request_body) + + # Execute request. + response = request( + url=self._token_exchange_endpoint, + method="POST", + headers=request_headers, + body=urllib.parse.urlencode(request_body).encode("utf-8"), + ) + + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + + # If non-200 response received, translate to OAuthError exception. + if response.status != http_client.OK: + utils.handle_error_response(response_body) + + response_data = json.loads(response_body) + + # Return successful response. + return response_data + def exchange_token( self, request, @@ -102,12 +137,6 @@ def exchange_token( google.auth.exceptions.OAuthError: If the token endpoint returned an error. """ - # Initialize request headers. - headers = _URLENCODED_HEADERS.copy() - # Inject additional headers. - if additional_headers: - for k, v in dict(additional_headers).items(): - headers[k] = v # Initialize request body. request_body = { "grant_type": grant_type, @@ -128,28 +157,21 @@ def exchange_token( for k, v in dict(request_body).items(): if v is None or v == "": del request_body[k] - # Apply OAuth client authentication. - self.apply_client_authentication_options(headers, request_body) - # Execute request. - response = request( - url=self._token_exchange_endpoint, - method="POST", - headers=headers, - body=urllib.parse.urlencode(request_body).encode("utf-8"), - ) + return self._make_request(request, additional_headers, request_body) - response_body = ( - response.data.decode("utf-8") - if hasattr(response.data, "decode") - else response.data - ) + def refresh_token(self, request, refresh_token): + """Exchanges a refresh token for an access token based on the + RFC6749 spec. - # If non-200 response received, translate to OAuthError exception. - if response.status != http_client.OK: - utils.handle_error_response(response_body) - - response_data = json.loads(response_body) + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + subject_token (str): The OAuth 2.0 refresh token. + """ - # Return successful response. - return response_data + return self._make_request( + request, + None, + {"grant_type": "refresh_token", "refresh_token": refresh_token}, + ) diff --git a/tests/data/external_account_authorized_user.json b/tests/data/external_account_authorized_user.json new file mode 100644 index 000000000..e0bd20c8f --- /dev/null +++ b/tests/data/external_account_authorized_user.json @@ -0,0 +1,9 @@ +{ + "type": "external_account_authorized_user", + "audience": "//iam.googleapis.com/locations/global/workforcePools/$WORKFORCE_POOL_ID/providers/$PROVIDER_ID", + "refresh_token": "refreshToken", + "token_url": "https://sts.googleapis.com/v1/oauth/token", + "token_info_url": "https://sts.googleapis.com/v1/instrospect", + "client_id": "clientId", + "client_secret": "clientSecret" +} diff --git a/tests/oauth2/test_sts.py b/tests/oauth2/test_sts.py index f61a1d338..a543d42a8 100644 --- a/tests/oauth2/test_sts.py +++ b/tests/oauth2/test_sts.py @@ -50,6 +50,11 @@ class TestStsClient(object): "expires_in": 3600, "scope": "scope1 scope2", } + SUCCESS_RESPONSE_WITH_REFRESH = { + "access_token": "abc", + "refresh_token": "xyz", + "expires_in": 3600, + } ERROR_RESPONSE = { "error": "invalid_request", "error_description": "Invalid subject token", @@ -393,3 +398,83 @@ def test_exchange_token_non200_with_reqbody_auth(self): assert excinfo.match( r"Error code invalid_request: Invalid subject token - https://tools.ietf.org/html/rfc6749" ) + + def test_refresh_token_success(self): + """Test refresh token with successful response.""" + client = self.make_client(self.CLIENT_AUTH_BASIC) + request = self.make_mock_request( + status=http_client.OK, data=self.SUCCESS_RESPONSE + ) + + response = client.refresh_token(request, "refreshtoken") + + headers = { + "Authorization": "Basic dXNlcm5hbWU6cGFzc3dvcmQ=", + "Content-Type": "application/x-www-form-urlencoded", + } + request_data = {"grant_type": "refresh_token", "refresh_token": "refreshtoken"} + self.assert_request_kwargs(request.call_args[1], headers, request_data) + assert response == self.SUCCESS_RESPONSE + + def test_refresh_token_success_with_refresh(self): + """Test refresh token with successful response.""" + client = self.make_client(self.CLIENT_AUTH_BASIC) + request = self.make_mock_request( + status=http_client.OK, data=self.SUCCESS_RESPONSE_WITH_REFRESH + ) + + response = client.refresh_token(request, "refreshtoken") + + headers = { + "Authorization": "Basic dXNlcm5hbWU6cGFzc3dvcmQ=", + "Content-Type": "application/x-www-form-urlencoded", + } + request_data = {"grant_type": "refresh_token", "refresh_token": "refreshtoken"} + self.assert_request_kwargs(request.call_args[1], headers, request_data) + assert response == self.SUCCESS_RESPONSE_WITH_REFRESH + + def test_refresh_token_failure(self): + """Test refresh token with failure response.""" + client = self.make_client(self.CLIENT_AUTH_BASIC) + request = self.make_mock_request( + status=http_client.BAD_REQUEST, data=self.ERROR_RESPONSE + ) + + with pytest.raises(exceptions.OAuthError) as excinfo: + client.refresh_token(request, "refreshtoken") + + assert excinfo.match( + r"Error code invalid_request: Invalid subject token - https://tools.ietf.org/html/rfc6749" + ) + + def test__make_request_success(self): + """Test base method with successful response.""" + client = self.make_client(self.CLIENT_AUTH_BASIC) + request = self.make_mock_request( + status=http_client.OK, data=self.SUCCESS_RESPONSE + ) + + response = client._make_request(request, {"a": "b"}, {"c": "d"}) + + headers = { + "Authorization": "Basic dXNlcm5hbWU6cGFzc3dvcmQ=", + "Content-Type": "application/x-www-form-urlencoded", + "a": "b", + } + request_data = {"c": "d"} + self.assert_request_kwargs(request.call_args[1], headers, request_data) + assert response == self.SUCCESS_RESPONSE + + def test_make_request_failure(self): + """Test refresh token with failure response.""" + client = self.make_client(self.CLIENT_AUTH_BASIC) + request = self.make_mock_request( + status=http_client.BAD_REQUEST, data=self.ERROR_RESPONSE + ) + + with pytest.raises(exceptions.OAuthError) as excinfo: + client._make_request(request, {"a": "b"}, {"c": "d"}) + + assert excinfo.match( + r"Error code invalid_request: Invalid subject token - https://tools.ietf.org/html/rfc6749" + ) diff --git a/tests/test__default.py b/tests/test__default.py index 5ea9c73c5..11d87f4cb 100644 --- a/tests/test__default.py +++ b/tests/test__default.py @@ -26,6 +26,7 @@ from google.auth import environment_vars from google.auth import exceptions from google.auth import external_account +from google.auth import external_account_authorized_user from google.auth import identity_pool from google.auth import impersonated_credentials from google.auth import pluggable @@ -151,6 +152,9 @@ DATA_DIR, "impersonated_service_account_service_account_source.json" ) +EXTERNAL_ACCOUNT_AUTHORIZED_USER_FILE = os.path.join( + DATA_DIR, "external_account_authorized_user.json" +) MOCK_CREDENTIALS = mock.Mock(spec=credentials.CredentialsWithQuotaProject) MOCK_CREDENTIALS.with_quota_project.return_value = MOCK_CREDENTIALS @@ -540,6 +544,29 @@ def test__get_explicit_environ_credentials_no_env(): assert _default._get_explicit_environ_credentials() == (None, None) +def test_load_credentials_from_file_external_account_authorized_user(): + credentials, project_id = _default.load_credentials_from_file( + EXTERNAL_ACCOUNT_AUTHORIZED_USER_FILE, request=mock.sentinel.request + ) + + assert isinstance(credentials, external_account_authorized_user.Credentials) + assert project_id is None + + +def test_load_credentials_from_file_external_account_authorized_user_bad_format(tmpdir): + filename = tmpdir.join("external_account_authorized_user_bad.json") + filename.write(json.dumps({"type": "external_account_authorized_user"})) + + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default.load_credentials_from_file(str(filename)) + + assert excinfo.match( + "Failed to load external account authorized user credentials from {}".format( + str(filename) + ) + ) + + @pytest.mark.parametrize("quota_project_id", [None, "project-foo"]) @LOAD_FILE_PATCH def test__get_explicit_environ_credentials(load, quota_project_id, monkeypatch): diff --git a/tests/test_external_account_authorized_user.py b/tests/test_external_account_authorized_user.py new file mode 100644 index 000000000..49c34a9a4 --- /dev/null +++ b/tests/test_external_account_authorized_user.py @@ -0,0 +1,463 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import json + +import mock +import pytest # type: ignore +from six.moves import http_client + +from google.auth import exceptions +from google.auth import external_account_authorized_user +from google.auth import transport + +TOKEN_URL = "https://sts.googleapis.com/v1/token" +TOKEN_INFO_URL = "https://sts.googleapis.com/v1/introspect" +REVOKE_URL = "https://sts.googleapis.com/v1/revoke" +PROJECT_NUMBER = "123456" +QUOTA_PROJECT_ID = "654321" +POOL_ID = "POOL_ID" +PROVIDER_ID = "PROVIDER_ID" +AUDIENCE = ( + "//iam.googleapis.com/projects/{}" + "/locations/global/workloadIdentityPools/{}" + "/providers/{}" +).format(PROJECT_NUMBER, POOL_ID, PROVIDER_ID) +REFRESH_TOKEN = "REFRESH_TOKEN" +NEW_REFRESH_TOKEN = "NEW_REFRESH_TOKEN" +ACCESS_TOKEN = "ACCESS_TOKEN" +CLIENT_ID = "username" +CLIENT_SECRET = "password" +# Base64 encoding of "username:password". +BASIC_AUTH_ENCODING = "dXNlcm5hbWU6cGFzc3dvcmQ=" + + +class TestCredentials(object): + @classmethod + def make_credentials( + cls, + audience=AUDIENCE, + refresh_token=REFRESH_TOKEN, + token_url=TOKEN_URL, + token_info_url=TOKEN_INFO_URL, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + **kwargs + ): + return external_account_authorized_user.Credentials( + audience=audience, + refresh_token=refresh_token, + token_url=token_url, + token_info_url=token_info_url, + client_id=client_id, + client_secret=client_secret, + **kwargs + ) + + @classmethod + def make_mock_request(cls, status=http_client.OK, data=None): + # STS token exchange request. + token_response = mock.create_autospec(transport.Response, instance=True) + token_response.status = status + token_response.data = json.dumps(data).encode("utf-8") + responses = [token_response] + + request = mock.create_autospec(transport.Request) + request.side_effect = responses + + return request + + def test_default_state(self): + creds = self.make_credentials() + + assert not creds.expiry + assert not creds.expired + assert not creds.token + assert not creds.valid + assert not creds.requires_scopes + assert creds.is_user + + def test_basic_create(self): + creds = external_account_authorized_user.Credentials( + token=ACCESS_TOKEN, expiry=datetime.datetime.max + ) + + assert creds.expiry == datetime.datetime.max + assert not creds.expired + assert creds.token == ACCESS_TOKEN + assert creds.valid + assert not creds.requires_scopes + assert creds.is_user + + def test_stunted_create(self): + with pytest.raises(ValueError) as excinfo: + self.make_credentials(token=None, refresh_token=None) + + assert excinfo.match(r"Either `refresh_token` or `token` should be set") + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_auth_success(self, utcnow): + request = self.make_mock_request( + status=http_client.OK, + data={"access_token": ACCESS_TOKEN, "expires_in": 3600}, + ) + creds = self.make_credentials() + + creds.refresh(request) + + assert creds.expiry == utcnow() + datetime.timedelta(seconds=3600) + assert not creds.expired + assert creds.token == ACCESS_TOKEN + assert creds.valid + assert not creds.requires_scopes + assert creds.is_user + assert creds._refresh_token == REFRESH_TOKEN + + request.assert_called_once_with( + url=TOKEN_URL, + method="POST", + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic " + BASIC_AUTH_ENCODING, + }, + body=("grant_type=refresh_token&refresh_token=" + REFRESH_TOKEN).encode( + "UTF-8" + ), + ) + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_auth_success_new_refresh_token(self, utcnow): + request = self.make_mock_request( + status=http_client.OK, + data={ + "access_token": ACCESS_TOKEN, + "expires_in": 3600, + "refresh_token": NEW_REFRESH_TOKEN, + }, + ) + creds = self.make_credentials() + + creds.refresh(request) + + assert creds.expiry == utcnow() + datetime.timedelta(seconds=3600) + assert not creds.expired + assert creds.token == ACCESS_TOKEN + assert creds.valid + assert not creds.requires_scopes + assert creds.is_user + assert creds._refresh_token == NEW_REFRESH_TOKEN + + request.assert_called_once_with( + url=TOKEN_URL, + method="POST", + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic " + BASIC_AUTH_ENCODING, + }, + body=("grant_type=refresh_token&refresh_token=" + REFRESH_TOKEN).encode( + "UTF-8" + ), + ) + + def test_refresh_auth_failure(self): + request = self.make_mock_request( + status=http_client.BAD_REQUEST, + data={ + "error": "invalid_request", + "error_description": "Invalid subject token", + "error_uri": "https://tools.ietf.org/html/rfc6749", + }, + ) + creds = self.make_credentials() + + with pytest.raises(exceptions.OAuthError) as excinfo: + creds.refresh(request) + + assert excinfo.match( + r"Error code invalid_request: Invalid subject token - https://tools.ietf.org/html/rfc6749" + ) + + assert not creds.expiry + assert not creds.expired + assert not creds.token + assert not creds.valid + assert not creds.requires_scopes + assert creds.is_user + + request.assert_called_once_with( + url=TOKEN_URL, + method="POST", + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic " + BASIC_AUTH_ENCODING, + }, + body=("grant_type=refresh_token&refresh_token=" + REFRESH_TOKEN).encode( + "UTF-8" + ), + ) + + def test_refresh_without_refresh_token(self): + request = self.make_mock_request() + creds = self.make_credentials(refresh_token=None, token=ACCESS_TOKEN) + + with pytest.raises(exceptions.RefreshError) as excinfo: + creds.refresh(request) + + assert excinfo.match( + r"The credentials do not contain the necessary fields need to refresh the access token. You must specify refresh_token, token_url, client_id, and client_secret." + ) + + assert not creds.expiry + assert not creds.expired + assert not creds.requires_scopes + assert creds.is_user + + request.assert_not_called() + + def test_refresh_without_token_url(self): + request = self.make_mock_request() + creds = self.make_credentials(token_url=None) + + with pytest.raises(exceptions.RefreshError) as excinfo: + creds.refresh(request) + + assert excinfo.match( + r"The credentials do not contain the necessary fields need to refresh the access token. You must specify refresh_token, token_url, client_id, and client_secret." + ) + + assert not creds.expiry + assert not creds.expired + assert not creds.token + assert not creds.valid + assert not creds.requires_scopes + assert creds.is_user + + request.assert_not_called() + + def test_refresh_without_client_id(self): + request = self.make_mock_request() + creds = self.make_credentials(client_id=None) + + with pytest.raises(exceptions.RefreshError) as excinfo: + creds.refresh(request) + + assert excinfo.match( + r"The credentials do not contain the necessary fields need to refresh the access token. You must specify refresh_token, token_url, client_id, and client_secret." + ) + + assert not creds.expiry + assert not creds.expired + assert not creds.token + assert not creds.valid + assert not creds.requires_scopes + assert creds.is_user + + request.assert_not_called() + + def test_refresh_without_client_secret(self): + request = self.make_mock_request() + creds = self.make_credentials(client_secret=None) + + with pytest.raises(exceptions.RefreshError) as excinfo: + creds.refresh(request) + + assert excinfo.match( + r"The credentials do not contain the necessary fields need to refresh the access token. You must specify refresh_token, token_url, client_id, and client_secret." + ) + + assert not creds.expiry + assert not creds.expired + assert not creds.token + assert not creds.valid + assert not creds.requires_scopes + assert creds.is_user + + request.assert_not_called() + + def test_info(self): + creds = self.make_credentials() + info = creds.info + + assert info["audience"] == AUDIENCE + assert info["refresh_token"] == REFRESH_TOKEN + assert info["token_url"] == TOKEN_URL + assert info["token_info_url"] == TOKEN_INFO_URL + assert info["client_id"] == CLIENT_ID + assert info["client_secret"] == CLIENT_SECRET + assert "token" not in info + assert "expiry" not in info + assert "revoke_url" not in info + assert "quota_project_id" not in info + + def test_info_full(self): + creds = self.make_credentials( + token=ACCESS_TOKEN, + expiry=datetime.datetime.min, + revoke_url=REVOKE_URL, + quota_project_id=QUOTA_PROJECT_ID, + ) + info = creds.info + + assert info["audience"] == AUDIENCE + assert info["refresh_token"] == REFRESH_TOKEN + assert info["token_url"] == TOKEN_URL + assert info["token_info_url"] == TOKEN_INFO_URL + assert info["client_id"] == CLIENT_ID + assert info["client_secret"] == CLIENT_SECRET + assert info["token"] == ACCESS_TOKEN + assert info["expiry"] == datetime.datetime.min.isoformat() + "Z" + assert info["revoke_url"] == REVOKE_URL + assert info["quota_project_id"] == QUOTA_PROJECT_ID + + def test_to_json(self): + creds = self.make_credentials() + json_info = creds.to_json() + info = json.loads(json_info) + + assert info["audience"] == AUDIENCE + assert info["refresh_token"] == REFRESH_TOKEN + assert info["token_url"] == TOKEN_URL + assert info["token_info_url"] == TOKEN_INFO_URL + assert info["client_id"] == CLIENT_ID + assert info["client_secret"] == CLIENT_SECRET + assert "token" not in info + assert "expiry" not in info + assert "revoke_url" not in info + assert "quota_project_id" not in info + + def test_to_json_full(self): + creds = self.make_credentials( + token=ACCESS_TOKEN, + expiry=datetime.datetime.min, + revoke_url=REVOKE_URL, + quota_project_id=QUOTA_PROJECT_ID, + ) + json_info = creds.to_json() + info = json.loads(json_info) + + assert info["audience"] == AUDIENCE + assert info["refresh_token"] == REFRESH_TOKEN + assert info["token_url"] == TOKEN_URL + assert info["token_info_url"] == TOKEN_INFO_URL + assert info["client_id"] == CLIENT_ID + assert info["client_secret"] == CLIENT_SECRET + assert info["token"] == ACCESS_TOKEN + assert info["expiry"] == datetime.datetime.min.isoformat() + "Z" + assert info["revoke_url"] == REVOKE_URL + assert info["quota_project_id"] == QUOTA_PROJECT_ID + + def test_to_json_full_with_strip(self): + creds = self.make_credentials( + token=ACCESS_TOKEN, + expiry=datetime.datetime.min, + revoke_url=REVOKE_URL, + quota_project_id=QUOTA_PROJECT_ID, + ) + json_info = creds.to_json(strip=["token", "expiry"]) + info = json.loads(json_info) + + assert info["audience"] == AUDIENCE + assert info["refresh_token"] == REFRESH_TOKEN + assert info["token_url"] == TOKEN_URL + assert info["token_info_url"] == TOKEN_INFO_URL + assert info["client_id"] == CLIENT_ID + assert info["client_secret"] == CLIENT_SECRET + assert "token" not in info + assert "expiry" not in info + assert info["revoke_url"] == REVOKE_URL + assert info["quota_project_id"] == QUOTA_PROJECT_ID + + def test_get_project_id(self): + creds = self.make_credentials() + assert creds.get_project_id() is None + + def test_with_quota_project(self): + creds = self.make_credentials( + token=ACCESS_TOKEN, + expiry=datetime.datetime.min, + revoke_url=REVOKE_URL, + quota_project_id=QUOTA_PROJECT_ID, + ) + new_creds = creds.with_quota_project(QUOTA_PROJECT_ID) + assert new_creds._audience == creds._audience + assert new_creds._refresh_token == creds._refresh_token + assert new_creds._token_url == creds._token_url + assert new_creds._token_info_url == creds._token_info_url + assert new_creds._client_id == creds._client_id + assert new_creds._client_secret == creds._client_secret + assert new_creds.token == creds.token + assert new_creds.expiry == creds.expiry + assert new_creds._revoke_url == creds._revoke_url + assert new_creds._quota_project_id == QUOTA_PROJECT_ID + + def test_with_token_uri(self): + creds = self.make_credentials( + token=ACCESS_TOKEN, + expiry=datetime.datetime.min, + revoke_url=REVOKE_URL, + quota_project_id=QUOTA_PROJECT_ID, + ) + new_creds = creds.with_token_uri("https://google.com") + assert new_creds._audience == creds._audience + assert new_creds._refresh_token == creds._refresh_token + assert new_creds._token_url == "https://google.com" + assert new_creds._token_info_url == creds._token_info_url + assert new_creds._client_id == creds._client_id + assert new_creds._client_secret == creds._client_secret + assert new_creds.token == creds.token + assert new_creds.expiry == creds.expiry + assert new_creds._revoke_url == creds._revoke_url + assert new_creds._quota_project_id == creds._quota_project_id + + def test_from_file_required_options_only(self, tmpdir): + from_creds = self.make_credentials() + config_file = tmpdir.join("config.json") + config_file.write(from_creds.to_json()) + creds = external_account_authorized_user.Credentials.from_file(str(config_file)) + + assert isinstance(creds, external_account_authorized_user.Credentials) + assert creds._audience == AUDIENCE + assert creds._refresh_token == REFRESH_TOKEN + assert creds._token_url == TOKEN_URL + assert creds._token_info_url == TOKEN_INFO_URL + assert creds._client_id == CLIENT_ID + assert creds._client_secret == CLIENT_SECRET + assert creds.token is None + assert creds.expiry is None + assert creds._revoke_url is None + assert creds._quota_project_id is None + + def test_from_file_full_options(self, tmpdir): + from_creds = self.make_credentials( + token=ACCESS_TOKEN, + expiry=datetime.datetime.min, + revoke_url=REVOKE_URL, + quota_project_id=QUOTA_PROJECT_ID, + ) + config_file = tmpdir.join("config.json") + config_file.write(from_creds.to_json()) + creds = external_account_authorized_user.Credentials.from_file(str(config_file)) + + assert isinstance(creds, external_account_authorized_user.Credentials) + assert creds._audience == AUDIENCE + assert creds._refresh_token == REFRESH_TOKEN + assert creds._token_url == TOKEN_URL + assert creds._token_info_url == TOKEN_INFO_URL + assert creds._client_id == CLIENT_ID + assert creds._client_secret == CLIENT_SECRET + assert creds.token == ACCESS_TOKEN + assert creds.expiry == datetime.datetime.min + assert creds._revoke_url == REVOKE_URL + assert creds._quota_project_id == QUOTA_PROJECT_ID