Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adding domain-wide delegation flow in impersonated credential #1624

Merged
merged 14 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions google/auth/iam.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@
+ "/serviceAccounts/{}:signBlob"
)

_IAM_SIGNJWT_ENDPOINT = (
"https://iamcredentials.googleapis.com/v1/projects/-"
+ "/serviceAccounts/{}:signJwt"
)

_IAM_IDTOKEN_ENDPOINT = (
"https://iamcredentials.googleapis.com/v1/"
+ "projects/-/serviceAccounts/{}:generateIdToken"
Expand Down
101 changes: 100 additions & 1 deletion google/auth/impersonated_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,15 @@
from google.auth import iam
from google.auth import jwt
from google.auth import metrics
from google.oauth2 import _client


_REFRESH_ERROR = "Unable to acquire impersonated credentials"

_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds

_GOOGLE_OAUTH2_TOKEN_ENDPOINT = "https://oauth2.googleapis.com/token"


def _make_iam_token_request(
request,
Expand Down Expand Up @@ -177,6 +180,7 @@ def __init__(
target_principal,
target_scopes,
delegates=None,
subject=None,
lifetime=_DEFAULT_TOKEN_LIFETIME_SECS,
quota_project_id=None,
iam_endpoint_override=None,
Expand Down Expand Up @@ -204,9 +208,12 @@ def __init__(
quota_project_id (Optional[str]): The project ID used for quota and billing.
This project may be different from the project used to
create the credentials.
iam_endpoint_override (Optiona[str]): The full IAM endpoint override
iam_endpoint_override (Optional[str]): The full IAM endpoint override
with the target_principal embedded. This is useful when supporting
impersonation with regional endpoints.
subject (Optional[str]): sub field of a JWT. This field should only be set
if you wish to impersonate as a user. This feature is useful when
using domain wide delegation.
"""

super(Credentials, self).__init__()
Expand All @@ -231,6 +238,7 @@ def __init__(
self._target_principal = target_principal
self._target_scopes = target_scopes
self._delegates = delegates
self._subject = subject
self._lifetime = lifetime or _DEFAULT_TOKEN_LIFETIME_SECS
self.token = None
self.expiry = _helpers.utcnow()
Expand Down Expand Up @@ -275,6 +283,39 @@ def _update_token(self, request):
# Apply the source credentials authentication info.
self._source_credentials.apply(headers)

# If a subject is specified a domain-wide delegation auth-flow is initiated
# to impersonate as the provided subject (user).
if self._subject:
if self.universe_domain != credentials.DEFAULT_UNIVERSE_DOMAIN:
raise exceptions.GoogleAuthError(
"Domain-wide delegation is not supported in universes other "
+ "than googleapis.com"
)

now = _helpers.utcnow()
payload = {
"iss": self._target_principal,
"scope": _helpers.scopes_to_string(self._target_scopes or ()),
"sub": self._subject,
"aud": _GOOGLE_OAUTH2_TOKEN_ENDPOINT,
"iat": _helpers.datetime_to_secs(now),
"exp": _helpers.datetime_to_secs(now) + _DEFAULT_TOKEN_LIFETIME_SECS,
}

assertion = _sign_jwt_request(
request=request,
principal=self._target_principal,
headers=headers,
payload=payload,
delegates=self._delegates,
)

self.token, self.expiry, _ = _client.jwt_grant(
request, _GOOGLE_OAUTH2_TOKEN_ENDPOINT, assertion
)

return

self.token, self.expiry = _make_iam_token_request(
request=request,
principal=self._target_principal,
Expand Down Expand Up @@ -478,3 +519,61 @@ def refresh(self, request):
self.expiry = datetime.utcfromtimestamp(
jwt.decode(id_token, verify=False)["exp"]
)


def _sign_jwt_request(request, principal, headers, payload, delegates=[]):
"""Makes a request to the Google Cloud IAM service to sign a JWT using a
service account's system-managed private key.
Args:
request (Request): The Request object to use.
principal (str): The principal to request an access token for.
headers (Mapping[str, str]): Map of headers to transmit.
payload (Mapping[str, str]): The JWT payload to sign. Must be a
serialized JSON object that contains a JWT Claims Set.
delegates (Sequence[str]): The chained list of delegates required
to grant the final access_token. If set, the sequence of
identities must have "Service Account Token Creator" capability
granted to the prceeding identity. For example, if set to
[serviceAccountB, serviceAccountC], the source_credential
must have the Token Creator role on serviceAccountB.
serviceAccountB must have the Token Creator on
serviceAccountC.
Finally, C must have Token Creator on target_principal.
If left unset, source_credential must have that role on
target_principal.

Raises:
google.auth.exceptions.TransportError: Raised if there is an underlying
HTTP connection error
google.auth.exceptions.RefreshError: Raised if the impersonated
credentials are not available. Common reasons are
`iamcredentials.googleapis.com` is not enabled or the
`Service Account Token Creator` is not assigned
"""
iam_endpoint = iam._IAM_SIGNJWT_ENDPOINT.format(principal)

body = {"delegates": delegates, "payload": json.dumps(payload)}
body = json.dumps(body).encode("utf-8")

response = request(url=iam_endpoint, method="POST", headers=headers, body=body)

# support both string and bytes type response.data
response_body = (
response.data.decode("utf-8")
if hasattr(response.data, "decode")
else response.data
)

if response.status != http_client.OK:
raise exceptions.RefreshError(_REFRESH_ERROR, response_body)

try:
jwt_response = json.loads(response_body)
signed_jwt = jwt_response["signedJwt"]
return signed_jwt

except (KeyError, ValueError) as caught_exc:
new_exc = exceptions.RefreshError(
"{}: No signed JWT in response.".format(_REFRESH_ERROR), response_body
)
raise new_exc from caught_exc
120 changes: 120 additions & 0 deletions tests/test_impersonated_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ def mock_donor_credentials():
yield grant


@pytest.fixture
def mock_dwd_credentials():
with mock.patch("google.oauth2._client.jwt_grant", autospec=True) as grant:
grant.return_value = (
"1/fFAGRNJasdfz70BzhT3Zg",
_helpers.utcnow() + datetime.timedelta(seconds=500),
{},
)
yield grant


class MockResponse:
def __init__(self, json_data, status_code):
self.json_data = json_data
Expand Down Expand Up @@ -123,6 +134,7 @@ def make_credentials(
source_credentials=SOURCE_CREDENTIALS,
lifetime=LIFETIME,
target_principal=TARGET_PRINCIPAL,
subject=None,
iam_endpoint_override=None,
):

Expand All @@ -132,6 +144,7 @@ def make_credentials(
target_scopes=self.TARGET_SCOPES,
delegates=self.DELEGATES,
lifetime=lifetime,
subject=subject,
iam_endpoint_override=iam_endpoint_override,
)

Expand Down Expand Up @@ -238,6 +251,28 @@ def test_refresh_success(self, use_data_bytes, mock_donor_credentials):
== ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE
)

@pytest.mark.parametrize("use_data_bytes", [True, False])
def test_refresh_with_subject_success(self, use_data_bytes, mock_dwd_credentials):
credentials = self.make_credentials(subject="[email protected]", lifetime=None)

response_body = {"signedJwt": "example_signed_jwt"}

request = self.make_request(
data=json.dumps(response_body),
status=http_client.OK,
use_data_bytes=use_data_bytes,
)

with mock.patch(
"google.auth.metrics.token_request_access_token_impersonate",
return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,
):
credentials.refresh(request)

assert credentials.valid
assert not credentials.expired
assert credentials.token == "1/fFAGRNJasdfz70BzhT3Zg"

@pytest.mark.parametrize("use_data_bytes", [True, False])
def test_refresh_success_nonGdu(self, use_data_bytes, mock_donor_credentials):
source_credentials = service_account.Credentials(
Expand Down Expand Up @@ -418,6 +453,33 @@ def test_refresh_failure_http_error(self, mock_donor_credentials):
assert not credentials.valid
assert credentials.expired

def test_refresh_failure_subject_with_nondefault_domain(
self, mock_donor_credentials
):
source_credentials = service_account.Credentials(
SIGNER, "[email protected]", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(
source_credentials=source_credentials, subject="[email protected]"
)

expire_time = (_helpers.utcnow().replace(microsecond=0)).isoformat("T") + "Z"
response_body = {"accessToken": "token", "expireTime": expire_time}
request = self.make_request(
data=json.dumps(response_body), status=http_client.OK
)

with pytest.raises(exceptions.GoogleAuthError) as excinfo:
credentials.refresh(request)

assert excinfo.match(
"Domain-wide delegation is not supported in universes other "
+ "than googleapis.com"
)

assert not credentials.valid
assert credentials.expired

def test_expired(self):
credentials = self.make_credentials(lifetime=None)
assert credentials.expired
Expand Down Expand Up @@ -810,3 +872,61 @@ def test_id_token_with_quota_project(
id_creds.refresh(request)

assert id_creds.quota_project_id == "project-foo"

def test_sign_jwt_request_success(self):
principal = "[email protected]"
expected_signed_jwt = "correct_signed_jwt"

response_body = {"keyId": "1", "signedJwt": expected_signed_jwt}
request = self.make_request(
data=json.dumps(response_body), status=http_client.OK
)

signed_jwt = impersonated_credentials._sign_jwt_request(
request=request, principal=principal, headers={}, payload={}
)

assert signed_jwt == expected_signed_jwt
request.assert_called_once_with(
url="https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/[email protected]:signJwt",
method="POST",
headers={},
body=json.dumps({"delegates": [], "payload": json.dumps({})}).encode(
"utf-8"
),
)

def test_sign_jwt_request_http_error(self):
principal = "[email protected]"

request = self.make_request(
data="error_message", status=http_client.BAD_REQUEST
)

with pytest.raises(exceptions.RefreshError) as excinfo:
_ = impersonated_credentials._sign_jwt_request(
request=request, principal=principal, headers={}, payload={}
)

assert excinfo.match(impersonated_credentials._REFRESH_ERROR)

assert excinfo.value.args[0] == "Unable to acquire impersonated credentials"
assert excinfo.value.args[1] == "error_message"

def test_sign_jwt_request_invalid_response_error(self):
principal = "[email protected]"

request = self.make_request(data="invalid_data", status=http_client.OK)

with pytest.raises(exceptions.RefreshError) as excinfo:
_ = impersonated_credentials._sign_jwt_request(
request=request, principal=principal, headers={}, payload={}
)

assert excinfo.match(impersonated_credentials._REFRESH_ERROR)

assert (
excinfo.value.args[0]
== "Unable to acquire impersonated credentials: No signed JWT in response."
)
assert excinfo.value.args[1] == "invalid_data"