Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for asynchronous AuthorizedSession api #1577

Merged
merged 13 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 61 additions & 16 deletions google/auth/_exponential_backoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import random
import time

Expand All @@ -38,9 +39,8 @@
"""


class ExponentialBackoff:
"""An exponential backoff iterator. This can be used in a for loop to
perform requests with exponential backoff.
class _BaseExponentialBackoff:
"""An exponential backoff iterator base class.

Args:
total_attempts Optional[int]:
Expand Down Expand Up @@ -84,9 +84,40 @@ def __init__(
self._multiplier = multiplier
self._backoff_count = 0

def __iter__(self):
@property
def total_attempts(self):
"""The total amount of backoff attempts that will be made."""
return self._total_attempts

@property
def backoff_count(self):
"""The current amount of backoff attempts that have been made."""
return self._backoff_count

def _reset(self):
self._backoff_count = 0
self._current_wait_in_seconds = self._initial_wait_seconds

def _calculate_jitter(self):
jitter_variance = self._current_wait_in_seconds * self._randomization_factor
jitter = random.uniform(
self._current_wait_in_seconds - jitter_variance,
self._current_wait_in_seconds + jitter_variance,
)

return jitter


class ExponentialBackoff(_BaseExponentialBackoff):
"""An exponential backoff iterator. This can be used in a for loop to
perform requests with exponential backoff.
"""

def __init__(self, *args, **kwargs):
super(ExponentialBackoff, self).__init__(*args, **kwargs)

def __iter__(self):
self._reset()
return self

def __next__(self):
Expand All @@ -97,23 +128,37 @@ def __next__(self):
if self._backoff_count <= 1:
return self._backoff_count

jitter_variance = self._current_wait_in_seconds * self._randomization_factor
jitter = random.uniform(
self._current_wait_in_seconds - jitter_variance,
self._current_wait_in_seconds + jitter_variance,
)
jitter = self._calculate_jitter()

time.sleep(jitter)

self._current_wait_in_seconds *= self._multiplier
return self._backoff_count

@property
def total_attempts(self):
"""The total amount of backoff attempts that will be made."""
return self._total_attempts

@property
def backoff_count(self):
"""The current amount of backoff attempts that have been made."""
class AsyncExponentialBackoff(_BaseExponentialBackoff):
"""An async exponential backoff iterator. This can be used in a for loop to
perform async requests with exponential backoff.
"""

def __init__(self, *args, **kwargs):
super(AsyncExponentialBackoff, self).__init__(*args, **kwargs)

def __aiter__(self):
self._reset()
return self

async def __anext__(self):
if self._backoff_count >= self._total_attempts:
raise StopAsyncIteration
self._backoff_count += 1

if self._backoff_count <= 1:
return self._backoff_count

jitter = self._calculate_jitter()

await asyncio.sleep(jitter)

self._current_wait_in_seconds *= self._multiplier
return self._backoff_count
144 changes: 144 additions & 0 deletions google/auth/aio/transport/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transport - Asynchronous HTTP client library support.

:mod:`google.auth.aio` is designed to work with various asynchronous client libraries such
as aiohttp. In order to work across these libraries with different
interfaces some abstraction is needed.

This module provides two interfaces that are implemented by transport adapters
to support HTTP libraries. :class:`Request` defines the interface expected by
:mod:`google.auth` to make asynchronous requests. :class:`Response` defines the interface
for the return value of :class:`Request`.
"""

import abc
from typing import AsyncGenerator, Mapping, Optional

import google.auth.transport


_DEFAULT_TIMEOUT_SECONDS = 180

DEFAULT_RETRYABLE_STATUS_CODES = google.auth.transport.DEFAULT_RETRYABLE_STATUS_CODES
"""Sequence[int]: HTTP status codes indicating a request can be retried.
"""


DEFAULT_MAX_RETRY_ATTEMPTS = 3
"""int: How many times to retry a request."""


class Response(metaclass=abc.ABCMeta):
"""Asynchronous HTTP Response Interface."""

@property
@abc.abstractmethod
def status_code(self) -> int:
"""
The HTTP response status code.

Returns:
int: The HTTP response status code.

"""
raise NotImplementedError("status_code must be implemented.")

@property
@abc.abstractmethod
def headers(self) -> Mapping[str, str]:
"""The HTTP response headers.

Returns:
Mapping[str, str]: The HTTP response headers.
"""
raise NotImplementedError("headers must be implemented.")

@abc.abstractmethod
async def content(self, chunk_size: int) -> AsyncGenerator[bytes, None]:
"""The raw response content.

Args:
chunk_size (int): The size of each chunk.

Yields:
AsyncGenerator[bytes, None]: An asynchronous generator yielding
response chunks as bytes.
"""
raise NotImplementedError("content must be implemented.")

@abc.abstractmethod
async def read(self) -> bytes:
"""Read the entire response content as bytes.

Returns:
bytes: The entire response content.
"""
raise NotImplementedError("read must be implemented.")

@abc.abstractmethod
async def close(self):
"""Close the response after it is fully consumed to resource."""
raise NotImplementedError("close must be implemented.")


class Request(metaclass=abc.ABCMeta):
"""Interface for a callable that makes HTTP requests.

Specific transport implementations should provide an implementation of
this that adapts their specific request / response API.

.. automethod:: __call__
"""

@abc.abstractmethod
async def __call__(
self,
url: str,
method: str,
body: Optional[bytes],
headers: Optional[Mapping[str, str]],
timeout: float,
**kwargs
) -> Response:
"""Make an HTTP request.

Args:
url (str): The URI to be requested.
method (str): The HTTP method to use for the request. Defaults
to 'GET'.
body (Optional[bytes]): The payload / body in HTTP request.
headers (Mapping[str, str]): Request headers.
timeout (float): The number of seconds to wait for a
response from the server. If not specified or if None, the
transport-specific default timeout will be used.
kwargs: Additional arguments passed on to the transport's
request method.

Returns:
google.auth.aio.transport.Response: The HTTP response.

Raises:
google.auth.exceptions.TransportError: If any exception occurred.
"""
# pylint: disable=redundant-returns-doc, missing-raises-doc
# (pylint doesn't play well with abstract docstrings.)
raise NotImplementedError("__call__ must be implemented.")

async def close(self) -> None:
"""
Close the underlying session.
"""
raise NotImplementedError("close must be implemented.")
Loading