Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

experimental prefect-aws bundle steps #17201

Merged
merged 10 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 10 additions & 61 deletions src/integrations/prefect-aws/prefect_aws/deployments/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
compatible services.
"""

from __future__ import annotations

from pathlib import Path, PurePosixPath
from typing import Dict, Optional
from typing import Any, Optional

import boto3
from botocore.client import Config
from typing_extensions import TypedDict

from prefect._internal.compatibility.deprecated import deprecated_callable
from prefect.utilities.filesystem import filter_files, relative_path_to_current_platform
from prefect_aws.s3 import get_s3_client


class PushToS3Output(TypedDict):
Expand Down Expand Up @@ -44,16 +45,16 @@ class PullProjectFromS3Output(PullFromS3Output):


@deprecated_callable(start_date="Jun 2023", help="Use `push_to_s3` instead.")
def push_project_to_s3(*args: Any, **kwargs: Any):
    """Deprecated. Use `push_to_s3` instead."""
    # Forward the result so legacy callers that relied on the returned
    # PushToS3Output mapping keep working (previously the value was discarded).
    return push_to_s3(*args, **kwargs)


def push_to_s3(
bucket: str,
folder: str,
credentials: Optional[Dict] = None,
client_parameters: Optional[Dict] = None,
credentials: Optional[dict[str, Any]] = None,
client_parameters: Optional[dict[str, Any]] = None,
ignore_file: Optional[str] = ".prefectignore",
) -> PushToS3Output:
"""
Expand Down Expand Up @@ -124,16 +125,16 @@ def push_to_s3(


@deprecated_callable(start_date="Jun 2023", help="Use `pull_from_s3` instead.")
def pull_project_from_s3(*args: Any, **kwargs: Any):
    """Deprecated. Use `pull_from_s3` instead."""
    # Forward the result so legacy callers that relied on the returned
    # PullProjectFromS3Output mapping keep working (previously it was discarded).
    return pull_from_s3(*args, **kwargs)


def pull_from_s3(
bucket: str,
folder: str,
credentials: Optional[Dict] = None,
client_parameters: Optional[Dict] = None,
credentials: Optional[dict[str, Any]] = None,
client_parameters: Optional[dict[str, Any]] = None,
) -> PullFromS3Output:
"""
Pulls the contents of an S3 bucket folder to the current working directory.
Expand Down Expand Up @@ -196,55 +197,3 @@ def pull_from_s3(
"folder": folder,
"directory": str(local_path),
}


def get_s3_client(
    credentials: Optional[Dict] = None,
    client_parameters: Optional[Dict] = None,
) -> dict:
    # NOTE(review): the annotation above looks wrong — this returns the boto3
    # S3 client from `session.client(...)`, not a dict. Left unchanged here
    # because `Any` is not imported in this module's view.
    """
    Build a boto3 S3 client from step-style mappings.

    Args:
        credentials: Mapping that may provide ``aws_access_key_id`` /
            ``aws_secret_access_key`` (or the MinIO aliases
            ``minio_root_user`` / ``minio_root_password``),
            ``aws_session_token``, ``profile_name``, ``region_name``, and
            optionally ``aws_client_parameters``.
        client_parameters: Fallback mapping consulted for ``profile_name``,
            ``region_name``, and client options (``api_version``,
            ``endpoint_url``, ``use_ssl``, ``verify``, ``config``).

    Returns:
        A configured boto3 S3 client.
    """
    if credentials is None:
        credentials = {}
    if client_parameters is None:
        client_parameters = {}

    # Get credentials from credentials (regardless if block or not);
    # the MinIO keys are accepted as aliases for the AWS ones.
    aws_access_key_id = credentials.get(
        "aws_access_key_id", credentials.get("minio_root_user", None)
    )
    aws_secret_access_key = credentials.get(
        "aws_secret_access_key", credentials.get("minio_root_password", None)
    )
    aws_session_token = credentials.get("aws_session_token", None)

    # Get remaining session info from credentials, or client_parameters
    profile_name = credentials.get(
        "profile_name", client_parameters.get("profile_name", None)
    )
    region_name = credentials.get(
        "region_name", client_parameters.get("region_name", None)
    )

    # Get additional info from client_parameters, otherwise credentials input (if block)
    aws_client_parameters = credentials.get("aws_client_parameters", client_parameters)
    api_version = aws_client_parameters.get("api_version", None)
    endpoint_url = aws_client_parameters.get("endpoint_url", None)
    use_ssl = aws_client_parameters.get("use_ssl", True)
    verify = aws_client_parameters.get("verify", None)
    config_params = aws_client_parameters.get("config", {})
    config = Config(**config_params)

    session = boto3.Session(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_session_token=aws_session_token,
        profile_name=profile_name,
        region_name=region_name,
    )
    return session.client(
        "s3",
        api_version=api_version,
        endpoint_url=endpoint_url,
        use_ssl=use_ssl,
        verify=verify,
        config=config,
    )
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from __future__ import annotations

import os
import tempfile
from pathlib import Path

import typer
from botocore.exceptions import ClientError
from pydantic_core import from_json

from prefect.runner import Runner
from prefect.utilities.asyncutils import run_coro_as_sync
from prefect_aws.credentials import AwsCredentials

from .types import AwsCredentialsBlockName, S3Bucket, S3Key


def download_bundle_from_s3(
    bucket: S3Bucket,
    key: S3Key,
    output_dir: str | None = None,
    aws_credentials_block_name: AwsCredentialsBlockName | None = None,
) -> dict[str, str]:
    """
    Downloads a bundle from an S3 bucket.

    Args:
        bucket: S3 bucket name
        key: S3 object key
        output_dir: Local directory to save the bundle (if None, uses a temp directory)
        aws_credentials_block_name: Name of the AWS credentials block to use.
            If None, credentials are inferred from the environment via the
            standard boto3 resolution chain.

    Returns:
        Dictionary with a single ``"local_path"`` key: the filesystem path of
        the downloaded bundle file.

    Raises:
        RuntimeError: If the download from S3 fails.
    """
    if aws_credentials_block_name:
        aws_credentials = AwsCredentials.load(aws_credentials_block_name)
    else:
        # No block given: fall back to ambient credentials (env vars, shared
        # config file, instance profile, ...).
        aws_credentials = AwsCredentials()

    s3 = aws_credentials.get_s3_client()

    output_dir = output_dir or tempfile.mkdtemp(prefix="prefect-bundle-")
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    local_path = Path(output_dir) / os.path.basename(key)

    try:
        # boto3's download_file expects a string filename; converting also
        # keeps the declared dict[str, str] return type accurate.
        s3.download_file(bucket, key, str(local_path))
        return {"local_path": str(local_path)}
    except ClientError as e:
        # Chain the original botocore error for easier debugging.
        raise RuntimeError(f"Failed to download bundle from S3: {e}") from e


def execute_bundle_from_s3(
    bucket: S3Bucket,
    key: S3Key,
    aws_credentials_block_name: AwsCredentialsBlockName | None = None,
) -> None:
    """
    Downloads a bundle from S3 and executes it.

    This step:
    1. Downloads the bundle from S3
    2. Extracts and deserializes the bundle
    3. Executes the flow in a subprocess

    Args:
        bucket: S3 bucket name
        key: S3 object key
        aws_credentials_block_name: Name of the AWS credentials block to use
    """
    result = download_bundle_from_s3(
        bucket=bucket, key=key, aws_credentials_block_name=aws_credentials_block_name
    )

    # Deserialize the downloaded JSON bundle, then hand it to a Runner.
    raw_bundle = Path(result["local_path"]).read_bytes()
    bundle_data = from_json(raw_bundle)

    runner = Runner()
    run_coro_as_sync(runner.execute_bundle(bundle_data))


# Allow running this step directly as a CLI; option names come from the
# typer-annotated parameter types in `.types`.
if __name__ == "__main__":
    typer.run(execute_bundle_from_s3)
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""
Shared typer-annotated parameter types for the S3 bundle step CLIs.

Each alias pairs a plain ``str`` with typer option/argument metadata so the
same function signature works both as a direct Python call and as a CLI
command. All aliases are plain strings at runtime.
"""

from typing import Annotated

import typer

# CLI option types.
S3Bucket = Annotated[str, typer.Option("--bucket")]
S3Key = Annotated[str, typer.Option("--key")]
AwsCredentialsBlockName = Annotated[str, typer.Option("--aws-credentials-block-name")]
# Positional CLI argument: path to a local bundle file.
LocalFilepath = Annotated[str, typer.Argument()]
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""
S3 bundle steps for Prefect.
These steps allow uploading and downloading flow/task bundles to and from S3.
"""

from __future__ import annotations

from pathlib import Path

import typer
from botocore.exceptions import ClientError

from prefect_aws.credentials import AwsCredentials

from .types import (
AwsCredentialsBlockName,
LocalFilepath,
S3Bucket,
S3Key,
)


def upload_bundle_to_s3(
    local_filepath: LocalFilepath,
    bucket: S3Bucket,
    key: S3Key,
    aws_credentials_block_name: AwsCredentialsBlockName | None = None,
) -> dict[str, str]:
    """
    Uploads a bundle file to an S3 bucket.

    Args:
        local_filepath: Local path to the bundle file
        bucket: S3 bucket name
        key: S3 object key (if empty, the bundle filename is used)
        aws_credentials_block_name: Name of the AWS credentials block to use.
            If None, credentials are inferred from the environment.

    Returns:
        Dictionary containing the bucket, key, and S3 URL of the uploaded bundle

    Raises:
        ValueError: If the local bundle file does not exist.
        RuntimeError: If the upload to S3 fails.
    """
    filepath = Path(local_filepath)
    if not filepath.exists():
        raise ValueError(f"Bundle file not found: {filepath}")

    # Fall back to the file's own name when no key was provided.
    key = key or filepath.name

    # Set up S3 client with credentials if provided
    if aws_credentials_block_name:
        aws_credentials = AwsCredentials.load(aws_credentials_block_name)
    else:
        aws_credentials = AwsCredentials()

    s3 = aws_credentials.get_s3_client()

    try:
        s3.upload_file(str(filepath), bucket, key)
        return {"bucket": bucket, "key": key, "url": f"s3://{bucket}/{key}"}
    except ClientError as e:
        # Chain the original botocore error for easier debugging.
        raise RuntimeError(f"Failed to upload bundle to S3: {e}") from e


# Allow running this step directly as a CLI; option names come from the
# typer-annotated parameter types in `.types`.
if __name__ == "__main__":
    typer.run(upload_bundle_to_s3)
54 changes: 54 additions & 0 deletions src/integrations/prefect-aws/prefect_aws/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from pathlib import Path
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union, get_args

import boto3
from botocore.client import Config
from botocore.paginate import PageIterator
from botocore.response import StreamingBody
from pydantic import Field, field_validator
Expand All @@ -23,6 +25,58 @@
from prefect_aws.client_parameters import AwsClientParameters


def get_s3_client(
    credentials: Optional[dict[str, Any]] = None,
    client_parameters: Optional[dict[str, Any]] = None,
) -> Any:
    """
    Build a boto3 S3 client from step-style mappings.

    Args:
        credentials: Mapping that may provide ``aws_access_key_id`` /
            ``aws_secret_access_key`` (or the MinIO aliases
            ``minio_root_user`` / ``minio_root_password``),
            ``aws_session_token``, ``profile_name``, ``region_name``, and
            optionally ``aws_client_parameters``.
        client_parameters: Fallback mapping consulted for ``profile_name``,
            ``region_name``, and client options (``api_version``,
            ``endpoint_url``, ``use_ssl``, ``verify``, ``config``).

    Returns:
        A configured boto3 S3 client. (Annotated ``Any``: the previous
        ``dict[str, Any]`` annotation did not match the actual return.)
    """
    if credentials is None:
        credentials = {}
    if client_parameters is None:
        client_parameters = {}

    # Get credentials from credentials (regardless if block or not);
    # the MinIO keys are accepted as aliases for the AWS ones.
    aws_access_key_id = credentials.get(
        "aws_access_key_id", credentials.get("minio_root_user", None)
    )
    aws_secret_access_key = credentials.get(
        "aws_secret_access_key", credentials.get("minio_root_password", None)
    )
    aws_session_token = credentials.get("aws_session_token", None)

    # Get remaining session info from credentials, or client_parameters
    profile_name = credentials.get(
        "profile_name", client_parameters.get("profile_name", None)
    )
    region_name = credentials.get(
        "region_name", client_parameters.get("region_name", None)
    )

    # Get additional info from client_parameters, otherwise credentials input (if block)
    aws_client_parameters = credentials.get("aws_client_parameters", client_parameters)
    api_version = aws_client_parameters.get("api_version", None)
    endpoint_url = aws_client_parameters.get("endpoint_url", None)
    use_ssl = aws_client_parameters.get("use_ssl", True)
    verify = aws_client_parameters.get("verify", None)
    config_params = aws_client_parameters.get("config", {})
    config = Config(**config_params)

    session = boto3.Session(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_session_token=aws_session_token,
        profile_name=profile_name,
        region_name=region_name,
    )
    return session.client(
        "s3",
        api_version=api_version,
        endpoint_url=endpoint_url,
        use_ssl=use_ssl,
        verify=verify,
        config=config,
    )


@task
async def adownload_from_bucket(
bucket: str,
Expand Down
Loading
Loading