Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added RefreshableSession object to session.py #4457

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 104 additions & 42 deletions boto3/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

import botocore.session
from botocore.client import Config
from botocore.credentials import (
DeferredRefreshableCredentials,
RefreshableCredentials,
)
from botocore.exceptions import (
DataNotFoundError,
NoCredentialsError,
Expand Down Expand Up @@ -69,17 +73,17 @@ def __init__(
self._session = botocore.session.get_session()

# Setup custom user-agent string if it isn't already customized
if self._session.user_agent_name == 'Botocore':
botocore_info = f'Botocore/{self._session.user_agent_version}'
if self._session.user_agent_name == "Botocore":
botocore_info = f"Botocore/{self._session.user_agent_version}"
if self._session.user_agent_extra:
self._session.user_agent_extra += ' ' + botocore_info
self._session.user_agent_extra += " " + botocore_info
else:
self._session.user_agent_extra = botocore_info
self._session.user_agent_name = 'Boto3'
self._session.user_agent_name = "Boto3"
self._session.user_agent_version = boto3.__version__

if profile_name is not None:
self._session.set_config_variable('profile', profile_name)
self._session.set_config_variable("profile", profile_name)

creds = (
aws_access_key_id,
Expand All @@ -100,40 +104,40 @@ def __init__(
)

if region_name is not None:
self._session.set_config_variable('region', region_name)
self._session.set_config_variable("region", region_name)

self.resource_factory = ResourceFactory(
self._session.get_component('event_emitter')
self._session.get_component("event_emitter")
)
self._setup_loader()
self._register_default_handlers()

def __repr__(self):
return '{}(region_name={})'.format(
return "{}(region_name={})".format(
self.__class__.__name__,
repr(self._session.get_config_variable('region')),
repr(self._session.get_config_variable("region")),
)

@property
def profile_name(self):
"""
The **read-only** profile name.
"""
return self._session.profile or 'default'
return self._session.profile or "default"

@property
def region_name(self):
"""
The **read-only** region name.
"""
return self._session.get_config_variable('region')
return self._session.get_config_variable("region")

@property
def events(self):
"""
The event emitter for a session
"""
return self._session.get_component('event_emitter')
return self._session.get_component("event_emitter")

@property
def available_profiles(self):
Expand All @@ -146,9 +150,9 @@ def _setup_loader(self):
"""
Setup loader paths so that we can load resources.
"""
self._loader = self._session.get_component('data_loader')
self._loader = self._session.get_component("data_loader")
self._loader.search_paths.append(
os.path.join(os.path.dirname(__file__), 'data')
os.path.join(os.path.dirname(__file__), "data")
)

def get_available_services(self):
Expand All @@ -169,7 +173,7 @@ def get_available_resources(self):
:rtype: list
:return: List of service names
"""
return self._loader.list_available_services(type_name='resources-1')
return self._loader.list_available_services(type_name="resources-1")

def get_available_partitions(self):
"""Lists the available partitions
Expand All @@ -180,7 +184,7 @@ def get_available_partitions(self):
return self._session.get_available_partitions()

def get_available_regions(
self, service_name, partition_name='aws', allow_non_regional=False
self, service_name, partition_name="aws", allow_non_regional=False
):
"""Lists the region and endpoint names of a particular partition.

Expand Down Expand Up @@ -418,7 +422,7 @@ def resource(
"""
try:
resource_model = self._loader.load_service_model(
service_name, 'resources-1', api_version
service_name, "resources-1", api_version
)
except UnknownServiceError:
available = self.get_available_resources()
Expand All @@ -431,10 +435,10 @@ def resource(
except DataNotFoundError:
# This is because we've provided an invalid API version.
available_api_versions = self._loader.list_api_versions(
service_name, 'resources-1'
service_name, "resources-1"
)
raise UnknownAPIVersionError(
service_name, api_version, ', '.join(available_api_versions)
service_name, api_version, ", ".join(available_api_versions)
)

if api_version is None:
Expand All @@ -454,7 +458,7 @@ def resource(
# and loader.determine_latest_version(..., 'resources-1')
# both load the same api version of the file.
api_version = self._loader.determine_latest_version(
service_name, 'resources-1'
service_name, "resources-1"
)

# Creating a new resource instance requires the low-level client
Expand All @@ -464,9 +468,9 @@ def resource(
if config is not None:
if config.user_agent_extra is None:
config = copy.deepcopy(config)
config.user_agent_extra = 'Resource'
config.user_agent_extra = "Resource"
else:
config = Config(user_agent_extra='Resource')
config = Config(user_agent_extra="Resource")
client = self.client(
service_name,
region_name=region_name,
Expand All @@ -486,7 +490,7 @@ def resource(
service_context = boto3.utils.ServiceContext(
service_name=service_name,
service_model=service_model,
resource_json_definitions=resource_model['resources'],
resource_json_definitions=resource_model["resources"],
service_waiter_model=boto3.utils.LazyLoadedWaiterModel(
self._session, service_name, api_version
),
Expand All @@ -495,7 +499,7 @@ def resource(
# Create the service resource class.
cls = self.resource_factory.load_from_definition(
resource_name=service_name,
single_resource_json_definition=resource_model['service'],
single_resource_json_definition=resource_model["service"],
service_context=service_context,
)

Expand All @@ -504,52 +508,52 @@ def resource(
def _register_default_handlers(self):
# S3 customizations
self._session.register(
'creating-client-class.s3',
"creating-client-class.s3",
boto3.utils.lazy_call(
'boto3.s3.inject.inject_s3_transfer_methods'
"boto3.s3.inject.inject_s3_transfer_methods"
),
)
self._session.register(
'creating-resource-class.s3.Bucket',
boto3.utils.lazy_call('boto3.s3.inject.inject_bucket_methods'),
"creating-resource-class.s3.Bucket",
boto3.utils.lazy_call("boto3.s3.inject.inject_bucket_methods"),
)
self._session.register(
'creating-resource-class.s3.Object',
boto3.utils.lazy_call('boto3.s3.inject.inject_object_methods'),
"creating-resource-class.s3.Object",
boto3.utils.lazy_call("boto3.s3.inject.inject_object_methods"),
)
self._session.register(
'creating-resource-class.s3.ObjectSummary',
"creating-resource-class.s3.ObjectSummary",
boto3.utils.lazy_call(
'boto3.s3.inject.inject_object_summary_methods'
"boto3.s3.inject.inject_object_summary_methods"
),
)

# DynamoDb customizations
self._session.register(
'creating-resource-class.dynamodb',
"creating-resource-class.dynamodb",
boto3.utils.lazy_call(
'boto3.dynamodb.transform.register_high_level_interface'
"boto3.dynamodb.transform.register_high_level_interface"
),
unique_id='high-level-dynamodb',
unique_id="high-level-dynamodb",
)
self._session.register(
'creating-resource-class.dynamodb.Table',
"creating-resource-class.dynamodb.Table",
boto3.utils.lazy_call(
'boto3.dynamodb.table.register_table_methods'
"boto3.dynamodb.table.register_table_methods"
),
unique_id='high-level-dynamodb-table',
unique_id="high-level-dynamodb-table",
)

# EC2 Customizations
self._session.register(
'creating-resource-class.ec2.ServiceResource',
boto3.utils.lazy_call('boto3.ec2.createtags.inject_create_tags'),
"creating-resource-class.ec2.ServiceResource",
boto3.utils.lazy_call("boto3.ec2.createtags.inject_create_tags"),
)

self._session.register(
'creating-resource-class.ec2.Instance',
"creating-resource-class.ec2.Instance",
boto3.utils.lazy_call(
'boto3.ec2.deletetags.inject_delete_tags',
"boto3.ec2.deletetags.inject_delete_tags",
event_emitter=self.events,
),
)
Expand All @@ -562,3 +566,61 @@ def _account_id_set_without_credentials(
elif access_key is None or secret_key is None:
return True
return False


class RefreshableSession(Session):
"""
A refreshable session stores configuration state and allows you to create service
clients and resources while automatically refreshing temporary security credentials.
Accepts all of the same parameters as the boto3.session.Session object.

:type assume_role_kwargs: dict
:param assume_role_kwargs: Required keyword arguments for the STS.Client.assume_role method.
:type defer_refresh: bool
:param defer_refresh: DeferredRefreshableCredentials or RefreshableCredentials.
:type sts_client_kwargs: dict
:param sts_client_kwargs: Optional keyword arguments for the STS.Client object.
"""

def __init__(
self,
assume_role_kwargs: dict,
defer_refresh: bool = True,
sts_client_kwargs: dict = None,
**kwargs,
):
super().__init__(**kwargs)
self.assume_role_kwargs = assume_role_kwargs
if sts_client_kwargs is not None:
self._sts_client = boto3.client(
service_name="sts", **sts_client_kwargs
)
else:
self._sts_client = boto3.client(service_name="sts")

# determining how exactly to refresh expired temporary credentials
if not defer_refresh:
self._session._credentials = (
RefreshableCredentials.create_from_metadata(
metadata=self._get_credentials(),
refresh_using=self._get_credentials,
method="sts-assume-role",
)
)
else:
self._session._credentials = DeferredRefreshableCredentials(
refresh_using=self._get_credentials, method="sts-assume-role"
)

def _get_credentials(self) -> dict:
_temporary_credentials = self._sts_client.assume_role(
**self.assume_role_kwargs
)["Credentials"]
return {
"access_key": _temporary_credentials.get("AccessKeyId"),
"secret_key": _temporary_credentials.get("SecretAccessKey"),
"token": _temporary_credentials.get("SessionToken"),
"expiry_time": _temporary_credentials.get(
"Expiration"
).isoformat(),
}