From 5cb7a4de23da4cffd5746043117879a6b07d32fc Mon Sep 17 00:00:00 2001 From: Giacomo Dell'omo <91473976+jack-dell@users.noreply.github.com> Date: Wed, 11 Sep 2024 11:27:38 +0200 Subject: [PATCH] feat(redshift): Automatically add new DataFrame columns to Redshift tables during write operation (#2948) * feat: Automatically add new columns to Redshift table during COPY operation * feat: Automatically add new columns to Redshift table during COPY operation * feat: Automatically add new columns to Redshift table during COPY operation * fix: ruff formatting * fix: ruff formatting * fix: get redshift_types only if needed * feat: Automatically add new columns to Redshift table during COPY operation * feat: Automatically add new columns to Redshift table during COPY operation * chore: code style * chore: code style * chore: code style * feat: Automatically add new columns to Redshift table during COPY operation * feat: Automatically add new columns to Redshift table during write operation * feat: Automatically add new columns to Redshift table during write operation * feat: Automatically add new columns to Redshift table during write operation * feat: Automatically add new columns to Redshift table during write operation --- awswrangler/redshift/_utils.py | 175 +++++++++++++++++++++++++-------- awswrangler/redshift/_write.py | 53 +++++++++- tests/unit/test_redshift.py | 169 +++++++++++++++++++++++++++++++ 3 files changed, 353 insertions(+), 44 deletions(-) diff --git a/awswrangler/redshift/_utils.py b/awswrangler/redshift/_utils.py index 8bb6c0543..6db432a67 100644 --- a/awswrangler/redshift/_utils.py +++ b/awswrangler/redshift/_utils.py @@ -106,6 +106,27 @@ def _get_primary_keys(cursor: "redshift_connector.Cursor", schema: str, table: s return fields +def _get_table_columns(cursor: "redshift_connector.Cursor", schema: str, table: str) -> list[str]: + sql = f"SELECT column_name FROM svv_columns\n WHERE table_schema = '{schema}' AND table_name = '{table}'" + _logger.debug("Executing select query:\n%s", sql) + cursor.execute(sql) + result: tuple[list[str]] = cursor.fetchall() + columns = ["".join(lst) for lst in result] + return columns + + +def _add_table_columns( + cursor: "redshift_connector.Cursor", schema: str, table: str, new_columns: dict[str, str] +) -> None: + for column_name, column_type in new_columns.items(): + sql = ( + f"ALTER TABLE {_identifier(schema)}.{_identifier(table)}" + f"\nADD COLUMN {_identifier(column_name)} {column_type};" + ) + _logger.debug("Executing alter query:\n%s", sql) + cursor.execute(sql) + + def _does_table_exist(cursor: "redshift_connector.Cursor", schema: str | None, table: str) -> bool: schema_str = f"TABLE_SCHEMA = '{schema}' AND" if schema else "" sql = ( @@ -128,6 +149,16 @@ def _get_paths_from_manifest(path: str, boto3_session: boto3.Session | None = No return paths +def _get_parameter_setting(cursor: "redshift_connector.Cursor", parameter_name: str) -> str: + sql = f"SHOW {parameter_name}" + _logger.debug("Executing select query:\n%s", sql) + cursor.execute(sql) + result = cursor.fetchall() + status = str(result[0][0]) + _logger.debug(f"{parameter_name}='{status}'") + return status + + def _lock( cursor: "redshift_connector.Cursor", table_names: list[str], @@ -267,7 +298,90 @@ def _redshift_types_from_path( return redshift_types -def _create_table( # noqa: PLR0912,PLR0913,PLR0915 +def _get_rsh_columns_types( + df: pd.DataFrame | None, + path: str | list[str] | None, + index: bool, + dtype: dict[str, str] | None, + varchar_lengths_default: int, + 
varchar_lengths: dict[str, int] | None, + data_format: Literal["parquet", "orc", "csv"] = "parquet", + redshift_column_types: dict[str, str] | None = None, + parquet_infer_sampling: float = 1.0, + path_suffix: str | None = None, + path_ignore_suffix: str | list[str] | None = None, + manifest: bool | None = False, + use_threads: bool | int = True, + boto3_session: boto3.Session | None = None, + s3_additional_kwargs: dict[str, str] | None = None, +) -> dict[str, str]: + if df is not None: + redshift_types: dict[str, str] = _data_types.database_types_from_pandas( + df=df, + index=index, + dtype=dtype, + varchar_lengths_default=varchar_lengths_default, + varchar_lengths=varchar_lengths, + converter_func=_data_types.pyarrow2redshift, + ) + _logger.debug("Converted redshift types from pandas: %s", redshift_types) + elif path is not None: + if manifest: + if not isinstance(path, str): + raise TypeError( + f"""type: {type(path)} is not a valid type for 'path' when 'manifest' is set to True; + must be a string""" + ) + path = _get_paths_from_manifest( + path=path, + boto3_session=boto3_session, + ) + + if data_format in ["parquet", "orc"]: + redshift_types = _redshift_types_from_path( + path=path, + data_format=data_format, # type: ignore[arg-type] + varchar_lengths_default=varchar_lengths_default, + varchar_lengths=varchar_lengths, + parquet_infer_sampling=parquet_infer_sampling, + path_suffix=path_suffix, + path_ignore_suffix=path_ignore_suffix, + use_threads=use_threads, + boto3_session=boto3_session, + s3_additional_kwargs=s3_additional_kwargs, + ) + else: + if redshift_column_types is None: + raise ValueError( + "redshift_column_types is None. It must be specified for files formats other than Parquet or ORC." + ) + redshift_types = redshift_column_types + else: + raise ValueError("df and path are None. 
You MUST pass at least one.") + return redshift_types + + +def _add_new_table_columns( + cursor: "redshift_connector.Cursor", schema: str, table: str, redshift_columns_types: dict[str, str] +) -> None: + # Check if Redshift is configured as case sensitive or not + is_case_sensitive = False + if _get_parameter_setting(cursor=cursor, parameter_name="enable_case_sensitive_identifier").lower() in [ + "on", + "true", + ]: + is_case_sensitive = True + + # If it is case-insensitive, convert all the DataFrame columns names to lowercase before performing the comparison + if is_case_sensitive is False: + redshift_columns_types = {key.lower(): value for key, value in redshift_columns_types.items()} + actual_table_columns = set(_get_table_columns(cursor=cursor, schema=schema, table=table)) + new_df_columns = {key: value for key, value in redshift_columns_types.items() if key not in actual_table_columns} + + _add_table_columns(cursor=cursor, schema=schema, table=table, new_columns=new_df_columns) + + +def _create_table( # noqa: PLR0913 df: pd.DataFrame | None, path: str | list[str] | None, con: "redshift_connector.Connection", @@ -336,49 +450,24 @@ def _create_table( # noqa: PLR0912,PLR0913,PLR0915 return table, schema diststyle = diststyle.upper() if diststyle else "AUTO" sortstyle = sortstyle.upper() if sortstyle else "COMPOUND" - if df is not None: - redshift_types: dict[str, str] = _data_types.database_types_from_pandas( - df=df, - index=index, - dtype=dtype, - varchar_lengths_default=varchar_lengths_default, - varchar_lengths=varchar_lengths, - converter_func=_data_types.pyarrow2redshift, - ) - _logger.debug("Converted redshift types from pandas: %s", redshift_types) - elif path is not None: - if manifest: - if not isinstance(path, str): - raise TypeError( - f"""type: {type(path)} is not a valid type for 'path' when 'manifest' is set to True; - must be a string""" - ) - path = _get_paths_from_manifest( - path=path, - boto3_session=boto3_session, - ) - if data_format in ["parquet", "orc"]: - redshift_types = _redshift_types_from_path( - path=path, - data_format=data_format, # type: ignore[arg-type] - varchar_lengths_default=varchar_lengths_default, - varchar_lengths=varchar_lengths, - parquet_infer_sampling=parquet_infer_sampling, - path_suffix=path_suffix, - path_ignore_suffix=path_ignore_suffix, - use_threads=use_threads, - boto3_session=boto3_session, - s3_additional_kwargs=s3_additional_kwargs, - ) - else: - if redshift_column_types is None: - raise ValueError( - "redshift_column_types is None. It must be specified for files formats other than Parquet or ORC." - ) - redshift_types = redshift_column_types - else: - raise ValueError("df and path are None. 
You MUST pass at least one.") + redshift_types = _get_rsh_columns_types( + df=df, + path=path, + index=index, + dtype=dtype, + varchar_lengths_default=varchar_lengths_default, + varchar_lengths=varchar_lengths, + parquet_infer_sampling=parquet_infer_sampling, + path_suffix=path_suffix, + path_ignore_suffix=path_ignore_suffix, + use_threads=use_threads, + boto3_session=boto3_session, + s3_additional_kwargs=s3_additional_kwargs, + data_format=data_format, + redshift_column_types=redshift_column_types, + manifest=manifest, + ) _validate_parameters( redshift_types=redshift_types, diststyle=diststyle, diff --git a/awswrangler/redshift/_write.py b/awswrangler/redshift/_write.py index 2a1d09296..c1aef4c74 100644 --- a/awswrangler/redshift/_write.py +++ b/awswrangler/redshift/_write.py @@ -13,7 +13,14 @@ from awswrangler._config import apply_configs from ._connect import _validate_connection -from ._utils import _create_table, _make_s3_auth_string, _upsert +from ._utils import ( + _add_new_table_columns, + _create_table, + _does_table_exist, + _get_rsh_columns_types, + _make_s3_auth_string, + _upsert, +) if TYPE_CHECKING: try: @@ -102,6 +109,7 @@ def to_sql( chunksize: int = 200, commit_transaction: bool = True, precombine_key: str | None = None, + add_new_columns: bool = False, ) -> None: """Write records stored in a DataFrame into Redshift. @@ -169,6 +177,8 @@ def to_sql( When there is a primary_key match during upsert, this column will change the upsert method, comparing the values of the specified column from source and target, and keeping the larger of the two. Will only work when mode = upsert. + add_new_columns + If True, it automatically adds the new DataFrame columns into the target table. Examples -------- @@ -191,6 +201,19 @@ def to_sql( con.autocommit = False try: with con.cursor() as cursor: + if add_new_columns and _does_table_exist(cursor=cursor, schema=schema, table=table): + redshift_columns_types = _get_rsh_columns_types( + df=df, + path=None, + index=index, + dtype=dtype, + varchar_lengths_default=varchar_lengths_default, + varchar_lengths=varchar_lengths, + ) + _add_new_table_columns( + cursor=cursor, schema=schema, table=table, redshift_columns_types=redshift_columns_types + ) + created_table, created_schema = _create_table( df=df, path=None, @@ -280,6 +303,7 @@ def copy_from_files( # noqa: PLR0913 s3_additional_kwargs: dict[str, str] | None = None, precombine_key: str | None = None, column_names: list[str] | None = None, + add_new_columns: bool = False, ) -> None: """Load files from S3 to a Table on Amazon Redshift (Through COPY command). @@ -396,6 +420,8 @@ def copy_from_files( # noqa: PLR0913 larger of the two. Will only work when mode = upsert. column_names List of column names to map source data fields to the target columns. + add_new_columns + If True, it automatically adds the new DataFrame columns into the target table. 
Examples -------- @@ -420,6 +446,27 @@ def copy_from_files( # noqa: PLR0913 con.autocommit = False try: with con.cursor() as cursor: + if add_new_columns and _does_table_exist(cursor=cursor, schema=schema, table=table): + redshift_columns_types = _get_rsh_columns_types( + df=None, + path=path, + index=False, + dtype=None, + varchar_lengths_default=varchar_lengths_default, + varchar_lengths=varchar_lengths, + parquet_infer_sampling=parquet_infer_sampling, + path_suffix=path_suffix, + path_ignore_suffix=path_ignore_suffix, + use_threads=use_threads, + boto3_session=boto3_session, + s3_additional_kwargs=s3_additional_kwargs, + data_format=data_format, + redshift_column_types=redshift_column_types, + manifest=manifest, + ) + _add_new_table_columns( + cursor=cursor, schema=schema, table=table, redshift_columns_types=redshift_columns_types + ) created_table, created_schema = _create_table( df=None, path=path, @@ -521,6 +568,7 @@ def copy( # noqa: PLR0913 max_rows_by_file: int | None = 10_000_000, precombine_key: str | None = None, use_column_names: bool = False, + add_new_columns: bool = False, ) -> None: """Load Pandas DataFrame as a Table on Amazon Redshift using parquet files on S3 as stage. @@ -628,6 +676,8 @@ def copy( # noqa: PLR0913 If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query. E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be inserted into the database columns `col1` and `col3`. + add_new_columns + If True, it automatically adds the new DataFrame columns into the target table. Examples -------- @@ -692,6 +742,7 @@ def copy( # noqa: PLR0913 sql_copy_extra_params=sql_copy_extra_params, precombine_key=precombine_key, column_names=column_names, + add_new_columns=add_new_columns, ) finally: if keep_files is False: diff --git a/tests/unit/test_redshift.py b/tests/unit/test_redshift.py index c7e99a79c..7207c5d24 100644 --- a/tests/unit/test_redshift.py +++ b/tests/unit/test_redshift.py @@ -1426,3 +1426,172 @@ def test_unload_escape_quotation_marks( keep_files=False, ) assert len(df2) == 1 + + +@pytest.mark.parametrize( + "mode,overwrite_method", + [ + ("append", ""), + ("upsert", ""), + ("overwrite", "drop"), + ("overwrite", "cascade"), + ("overwrite", "truncate"), + ("overwrite", "delete"), + ], +) +def test_copy_add_new_columns( + path: str, + redshift_table: str, + redshift_con: redshift_connector.Connection, + databases_parameters: dict[str, Any], + mode: str, + overwrite_method: str, +) -> None: + schema = "public" + df = pd.DataFrame({"foo": ["a", "b", "c"], "bar": ["c", "d", "e"]}) + copy_kwargs = { + "df": df, + "path": path, + "con": redshift_con, + "schema": schema, + "table": redshift_table, + "iam_role": databases_parameters["redshift"]["role"], + "primary_keys": ["foo"] if mode == "upsert" else None, + "overwrite_method": overwrite_method, + } + + # Create table + wr.redshift.copy(**copy_kwargs, add_new_columns=True, mode="overwrite") + copy_kwargs["mode"] = mode + + # Add new columns + df["xoo"] = ["f", "g", "h"] + df["baz"] = ["j", "k", "l"] + wr.redshift.copy(**copy_kwargs, add_new_columns=True) + + sql = f"SELECT * FROM {schema}.{redshift_table}" + if mode == "append": + sql += "\nWHERE xoo IS NOT NULL AND baz IS NOT NULL" + df2 = wr.redshift.read_sql_query(sql=sql, con=redshift_con) + df2 = df2.sort_values(by=df2.columns.to_list()) + assert df.values.tolist() == df2.values.tolist() + assert df.columns.tolist() == df2.columns.tolist() + + # Assert error when trying to add a 
new column without 'add_new_columns' parameter (False as default) in "append" + # or "upsert". No error are expected in ('drop', 'cascade') overwrite_method + df["abc"] = ["m", "n", "o"] + if overwrite_method in ("drop", "cascade"): + wr.redshift.copy(**copy_kwargs) + else: + with pytest.raises(redshift_connector.error.ProgrammingError) as exc_info: + wr.redshift.copy(**copy_kwargs) + assert "ProgrammingError" == exc_info.typename + assert "unmatched number of columns" in str(exc_info.value).lower() + + +@pytest.mark.parametrize( + "mode,overwrite_method", + [ + ("append", ""), + ("upsert", ""), + ("overwrite", "drop"), + ("overwrite", "cascade"), + ("overwrite", "truncate"), + ("overwrite", "delete"), + ], +) +def test_to_sql_add_new_columns( + path: str, + redshift_table: str, + redshift_con: redshift_connector.Connection, + databases_parameters: dict[str, Any], + mode: str, + overwrite_method: str, +) -> None: + schema = "public" + df = pd.DataFrame({"foo": ["a", "b", "c"], "bar": ["c", "d", "e"]}) + to_sql_kwargs = { + "df": df, + "con": redshift_con, + "schema": schema, + "table": redshift_table, + "primary_keys": ["foo"] if mode == "upsert" else None, + "overwrite_method": overwrite_method, + } + + # Create table + wr.redshift.to_sql(**to_sql_kwargs, add_new_columns=True, mode="overwrite") + to_sql_kwargs["mode"] = mode + + # Add new columns + df["xoo"] = ["f", "g", "h"] + df["baz"] = ["j", "k", "l"] + wr.redshift.to_sql(**to_sql_kwargs, add_new_columns=True) + + sql = f"SELECT * FROM {schema}.{redshift_table}" + if mode == "append": + sql += "\nWHERE xoo IS NOT NULL AND baz IS NOT NULL" + df2 = wr.redshift.read_sql_query(sql=sql, con=redshift_con) + df2 = df2.sort_values(by=df2.columns.to_list()) + assert df.values.tolist() == df2.values.tolist() + assert df.columns.tolist() == df2.columns.tolist() + + # Assert error when trying to add a new column without 'add_new_columns' parameter (False as default) in "append" + # or "upsert". 
No errors expected in ('drop', 'cascade') overwrite_method + df["abc"] = ["m", "n", "o"] + if overwrite_method in ("drop", "cascade"): + wr.redshift.to_sql(**to_sql_kwargs) + else: + with pytest.raises(redshift_connector.error.ProgrammingError) as exc_info: + wr.redshift.to_sql(**to_sql_kwargs) + assert "ProgrammingError" == exc_info.typename + assert "insert has more expressions than target columns" in str(exc_info.value).lower() + + with pytest.raises(redshift_connector.error.ProgrammingError) as exc_info: + wr.redshift.to_sql(**to_sql_kwargs, use_column_names=True) + assert "ProgrammingError" == exc_info.typename + assert 'column "abc" of relation' in str(exc_info.value).lower() + + +def test_add_new_columns_case_sensitive( + path: str, redshift_table: str, redshift_con: redshift_connector.Connection, databases_parameters: dict[str, Any] +) -> None: + schema = "public" + df = pd.DataFrame({"foo": ["a", "b", "c"]}) + + # Create table + wr.redshift.to_sql(df=df, con=redshift_con, table=redshift_table, schema=schema, add_new_columns=True) + + # Set enable_case_sensitive_identifier to False (default value) + with redshift_con.cursor() as cursor: + cursor.execute("SET enable_case_sensitive_identifier TO off;") + redshift_con.commit() + + df["Boo"] = ["f", "g", "h"] + wr.redshift.to_sql(df=df, con=redshift_con, table=redshift_table, schema=schema, add_new_columns=True) + df2 = wr.redshift.read_sql_query(sql=f"SELECT * FROM {schema}.{redshift_table}", con=redshift_con) + + # Since 'enable_case_sensitive_identifier' is set to False, the column 'Boo' is automatically written as 'boo' by + # Redshift + assert df2.columns.tolist() == [x.lower() for x in df.columns] + assert "boo" in df2.columns + + # Trying to add a new column 'BOO' causes an exception because Redshift attempts to lowercase it, resulting in a + # columns mismatch between the DataFrame and the table schema + df["BOO"] = ["j", "k", "l"] + with pytest.raises(redshift_connector.error.ProgrammingError) as exc_info: + wr.redshift.to_sql(df=df, con=redshift_con, table=redshift_table, schema=schema, add_new_columns=True) + assert "insert has more expressions than target columns" in str(exc_info.value).lower() + + # Enable enable_case_sensitive_identifier + with redshift_con.cursor() as cursor: + cursor.execute("SET enable_case_sensitive_identifier TO on;") + redshift_con.commit() + wr.redshift.to_sql(df=df, con=redshift_con, table=redshift_table, schema=schema, add_new_columns=True) + cursor.execute("RESET enable_case_sensitive_identifier;") + redshift_con.commit() + + # Ensure that the new uppercase columns have been added correctly + df2 = wr.redshift.read_sql_query(sql=f"SELECT * FROM {schema}.{redshift_table}", con=redshift_con) + expected_columns = list(sorted(df.columns.tolist() + ["boo"])) + assert expected_columns == list(sorted(df2.columns.tolist()))
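
Usage sketch for the new `add_new_columns` flag (default `False`) on `wr.redshift.to_sql`, mirroring `test_to_sql_add_new_columns`. The Glue connection name, schema and table below are placeholders, not values from the patch:

```python
import awswrangler as wr
import pandas as pd

# Placeholder Glue Catalog connection name; any redshift_connector connection works as well.
con = wr.redshift.connect("my-redshift-connection")

# First write creates the table with columns foo/bar.
df = pd.DataFrame({"foo": ["a", "b", "c"], "bar": ["c", "d", "e"]})
wr.redshift.to_sql(df=df, con=con, schema="public", table="my_table", mode="overwrite")

# Second write carries an extra column; with add_new_columns=True the column is added to
# the existing table via ALTER TABLE before the INSERT, instead of failing with a
# column-count mismatch.
df["baz"] = ["j", "k", "l"]
wr.redshift.to_sql(
    df=df,
    con=con,
    schema="public",
    table="my_table",
    mode="append",
    add_new_columns=True,
)
con.close()
```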
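The same flag is available on the COPY-based writers (`wr.redshift.copy` and `wr.redshift.copy_from_files`): before the COPY, `_get_rsh_columns_types` derives Redshift types for the missing columns from the DataFrame or from the staged Parquet/ORC files. A sketch mirroring `test_copy_add_new_columns`, reusing `df` and `con` from the sketch above; the S3 staging prefix and IAM role ARN are placeholders:

```python
wr.redshift.copy(
    df=df,
    path="s3://my-bucket/staging/",  # placeholder staging prefix
    con=con,
    schema="public",
    table="my_table",                # placeholder table name
    iam_role="arn:aws:iam::123456789012:role/my-redshift-copy-role",  # placeholder role
    mode="append",
    add_new_columns=True,
)
```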
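Column matching follows the cluster's `enable_case_sensitive_identifier` setting: `_add_new_table_columns` reads it with `SHOW` and, when identifiers are case-insensitive (the default), lowercases the DataFrame column names before comparing them against `svv_columns`, so a DataFrame column `Boo` maps onto an existing `boo` instead of being re-added. A sketch of how that surfaces, following `test_add_new_columns_case_sensitive` (table name is a placeholder):

```python
# With case-insensitive identifiers (the default), "Boo" is stored as lowercase "boo";
# a later "BOO" column would collide with it rather than be added as a new column.
with con.cursor() as cursor:
    cursor.execute("SET enable_case_sensitive_identifier TO off;")
con.commit()

df["Boo"] = ["x", "y", "z"]
wr.redshift.to_sql(
    df=df, con=con, schema="public", table="my_table", mode="append", add_new_columns=True
)
# The table now has a column named "boo".
```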
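Internally the schema evolution is two steps: `_get_table_columns` reads the existing column names from `svv_columns`, and `_add_table_columns` issues one `ALTER TABLE ... ADD COLUMN` per missing column. A condensed, hypothetical standalone sketch of that flow against a plain `redshift_connector` cursor (the helper name and identifiers below are illustrative, not part of the patch):

```python
import redshift_connector


def add_missing_columns(
    cursor: "redshift_connector.Cursor", schema: str, table: str, wanted: dict[str, str]
) -> None:
    """Hypothetical helper mirroring _get_table_columns + _add_table_columns."""
    # Existing column names from the Redshift system view used by the PR.
    cursor.execute(
        f"SELECT column_name FROM svv_columns "
        f"WHERE table_schema = '{schema}' AND table_name = '{table}'"
    )
    existing = {row[0] for row in cursor.fetchall()}
    # One ALTER TABLE ... ADD COLUMN per column not present yet.
    for name, redshift_type in wanted.items():
        if name not in existing:
            cursor.execute(f'ALTER TABLE "{schema}"."{table}" ADD COLUMN "{name}" {redshift_type}')
```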