
Commit 5cb7a4d

feat(redshift): Automatically add new DataFrame columns to Redshift tables during write operation (#2948)
* feat: Automatically add new columns to Redshift table during COPY operation
* feat: Automatically add new columns to Redshift table during COPY operation
* feat: Automatically add new columns to Redshift table during COPY operation
* fix: ruff formatting
* fix: ruff formatting
* fix: get redshift_types only if needed
* feat: Automatically add new columns to Redshift table during COPY operation
* feat: Automatically add new columns to Redshift table during COPY operation
* chore: code style
* chore: code style
* chore: code style
* feat: Automatically add new columns to Redshift table during COPY operation
* feat: Automatically add new columns to Redshift table during write operation
* feat: Automatically add new columns to Redshift table during write operation
* feat: Automatically add new columns to Redshift table during write operation
* feat: Automatically add new columns to Redshift table during write operation
1 parent 119dc4e commit 5cb7a4d
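
For context, a minimal usage sketch of the new flag (not part of the commit; the Glue connection name, bucket, IAM role, and table names below are illustrative placeholders):

import awswrangler as wr
import pandas as pd

con = wr.redshift.connect("aws-sdk-pandas-redshift")  # illustrative Glue connection name

# The DataFrame carries a column ("score") that does not yet exist in public.my_table.
df = pd.DataFrame({"id": [1, 2], "name": ["foo", "bar"], "score": [0.1, 0.2]})

# With add_new_columns=True, the missing column is added to the target table
# (ALTER TABLE ... ADD COLUMN) before the staged Parquet files are loaded via COPY.
wr.redshift.copy(
    df=df,
    path="s3://my-bucket/staging/",  # illustrative staging prefix
    con=con,
    table="my_table",
    schema="public",
    mode="append",
    iam_role="arn:aws:iam::111111111111:role/redshift-copy",  # illustrative role ARN
    add_new_columns=True,
)

con.close()

The same keyword is exposed on wr.redshift.to_sql and wr.redshift.copy_from_files (see the diffs below) and defaults to False, so existing behaviour is unchanged.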

File tree: 3 files changed (+353, -44 lines)

awswrangler/redshift/_utils.py

Lines changed: 132 additions & 43 deletions
@@ -106,6 +106,27 @@ def _get_primary_keys(cursor: "redshift_connector.Cursor", schema: str, table: s
     return fields
 
 
+def _get_table_columns(cursor: "redshift_connector.Cursor", schema: str, table: str) -> list[str]:
+    sql = f"SELECT column_name FROM svv_columns\n WHERE table_schema = '{schema}' AND table_name = '{table}'"
+    _logger.debug("Executing select query:\n%s", sql)
+    cursor.execute(sql)
+    result: tuple[list[str]] = cursor.fetchall()
+    columns = ["".join(lst) for lst in result]
+    return columns
+
+
+def _add_table_columns(
+    cursor: "redshift_connector.Cursor", schema: str, table: str, new_columns: dict[str, str]
+) -> None:
+    for column_name, column_type in new_columns.items():
+        sql = (
+            f"ALTER TABLE {_identifier(schema)}.{_identifier(table)}"
+            f"\nADD COLUMN {_identifier(column_name)} {column_type};"
+        )
+        _logger.debug("Executing alter query:\n%s", sql)
+        cursor.execute(sql)
+
+
 def _does_table_exist(cursor: "redshift_connector.Cursor", schema: str | None, table: str) -> bool:
     schema_str = f"TABLE_SCHEMA = '{schema}' AND" if schema else ""
     sql = (
@@ -128,6 +149,16 @@ def _get_paths_from_manifest(path: str, boto3_session: boto3.Session | None = No
     return paths
 
 
+def _get_parameter_setting(cursor: "redshift_connector.Cursor", parameter_name: str) -> str:
+    sql = f"SHOW {parameter_name}"
+    _logger.debug("Executing select query:\n%s", sql)
+    cursor.execute(sql)
+    result = cursor.fetchall()
+    status = str(result[0][0])
+    _logger.debug(f"{parameter_name}='{status}'")
+    return status
+
+
 def _lock(
     cursor: "redshift_connector.Cursor",
     table_names: list[str],
@@ -267,7 +298,90 @@ def _redshift_types_from_path(
     return redshift_types
 
 
-def _create_table(  # noqa: PLR0912,PLR0913,PLR0915
+def _get_rsh_columns_types(
+    df: pd.DataFrame | None,
+    path: str | list[str] | None,
+    index: bool,
+    dtype: dict[str, str] | None,
+    varchar_lengths_default: int,
+    varchar_lengths: dict[str, int] | None,
+    data_format: Literal["parquet", "orc", "csv"] = "parquet",
+    redshift_column_types: dict[str, str] | None = None,
+    parquet_infer_sampling: float = 1.0,
+    path_suffix: str | None = None,
+    path_ignore_suffix: str | list[str] | None = None,
+    manifest: bool | None = False,
+    use_threads: bool | int = True,
+    boto3_session: boto3.Session | None = None,
+    s3_additional_kwargs: dict[str, str] | None = None,
+) -> dict[str, str]:
+    if df is not None:
+        redshift_types: dict[str, str] = _data_types.database_types_from_pandas(
+            df=df,
+            index=index,
+            dtype=dtype,
+            varchar_lengths_default=varchar_lengths_default,
+            varchar_lengths=varchar_lengths,
+            converter_func=_data_types.pyarrow2redshift,
+        )
+        _logger.debug("Converted redshift types from pandas: %s", redshift_types)
+    elif path is not None:
+        if manifest:
+            if not isinstance(path, str):
+                raise TypeError(
+                    f"""type: {type(path)} is not a valid type for 'path' when 'manifest' is set to True;
+                    must be a string"""
+                )
+            path = _get_paths_from_manifest(
+                path=path,
+                boto3_session=boto3_session,
+            )
+
+        if data_format in ["parquet", "orc"]:
+            redshift_types = _redshift_types_from_path(
+                path=path,
+                data_format=data_format,  # type: ignore[arg-type]
+                varchar_lengths_default=varchar_lengths_default,
+                varchar_lengths=varchar_lengths,
+                parquet_infer_sampling=parquet_infer_sampling,
+                path_suffix=path_suffix,
+                path_ignore_suffix=path_ignore_suffix,
+                use_threads=use_threads,
+                boto3_session=boto3_session,
+                s3_additional_kwargs=s3_additional_kwargs,
+            )
+        else:
+            if redshift_column_types is None:
+                raise ValueError(
+                    "redshift_column_types is None. It must be specified for files formats other than Parquet or ORC."
+                )
+            redshift_types = redshift_column_types
+    else:
+        raise ValueError("df and path are None. You MUST pass at least one.")
+    return redshift_types
+
+
+def _add_new_table_columns(
+    cursor: "redshift_connector.Cursor", schema: str, table: str, redshift_columns_types: dict[str, str]
+) -> None:
+    # Check if Redshift is configured as case sensitive or not
+    is_case_sensitive = False
+    if _get_parameter_setting(cursor=cursor, parameter_name="enable_case_sensitive_identifier").lower() in [
+        "on",
+        "true",
+    ]:
+        is_case_sensitive = True
+
+    # If it is case-insensitive, convert all the DataFrame columns names to lowercase before performing the comparison
+    if is_case_sensitive is False:
+        redshift_columns_types = {key.lower(): value for key, value in redshift_columns_types.items()}
+    actual_table_columns = set(_get_table_columns(cursor=cursor, schema=schema, table=table))
+    new_df_columns = {key: value for key, value in redshift_columns_types.items() if key not in actual_table_columns}
+
+    _add_table_columns(cursor=cursor, schema=schema, table=table, new_columns=new_df_columns)
+
+
+def _create_table(  # noqa: PLR0913
     df: pd.DataFrame | None,
     path: str | list[str] | None,
     con: "redshift_connector.Connection",
@@ -336,49 +450,24 @@ def _create_table(  # noqa: PLR0912,PLR0913,PLR0915
         return table, schema
     diststyle = diststyle.upper() if diststyle else "AUTO"
     sortstyle = sortstyle.upper() if sortstyle else "COMPOUND"
-    if df is not None:
-        redshift_types: dict[str, str] = _data_types.database_types_from_pandas(
-            df=df,
-            index=index,
-            dtype=dtype,
-            varchar_lengths_default=varchar_lengths_default,
-            varchar_lengths=varchar_lengths,
-            converter_func=_data_types.pyarrow2redshift,
-        )
-        _logger.debug("Converted redshift types from pandas: %s", redshift_types)
-    elif path is not None:
-        if manifest:
-            if not isinstance(path, str):
-                raise TypeError(
-                    f"""type: {type(path)} is not a valid type for 'path' when 'manifest' is set to True;
-                    must be a string"""
-                )
-            path = _get_paths_from_manifest(
-                path=path,
-                boto3_session=boto3_session,
-            )
 
-        if data_format in ["parquet", "orc"]:
-            redshift_types = _redshift_types_from_path(
-                path=path,
-                data_format=data_format,  # type: ignore[arg-type]
-                varchar_lengths_default=varchar_lengths_default,
-                varchar_lengths=varchar_lengths,
-                parquet_infer_sampling=parquet_infer_sampling,
-                path_suffix=path_suffix,
-                path_ignore_suffix=path_ignore_suffix,
-                use_threads=use_threads,
-                boto3_session=boto3_session,
-                s3_additional_kwargs=s3_additional_kwargs,
-            )
-        else:
-            if redshift_column_types is None:
-                raise ValueError(
-                    "redshift_column_types is None. It must be specified for files formats other than Parquet or ORC."
-                )
-            redshift_types = redshift_column_types
-    else:
-        raise ValueError("df and path are None. You MUST pass at least one.")
+    redshift_types = _get_rsh_columns_types(
+        df=df,
+        path=path,
+        index=index,
+        dtype=dtype,
+        varchar_lengths_default=varchar_lengths_default,
+        varchar_lengths=varchar_lengths,
+        parquet_infer_sampling=parquet_infer_sampling,
+        path_suffix=path_suffix,
+        path_ignore_suffix=path_ignore_suffix,
+        use_threads=use_threads,
+        boto3_session=boto3_session,
+        s3_additional_kwargs=s3_additional_kwargs,
+        data_format=data_format,
+        redshift_column_types=redshift_column_types,
+        manifest=manifest,
+    )
    _validate_parameters(
        redshift_types=redshift_types,
        diststyle=diststyle,
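
To make the behaviour of the new helpers concrete, here is a standalone sketch of the column-diffing step (illustrative names only; the real helpers read the existing columns from svv_columns and the case-sensitivity setting via SHOW enable_case_sensitive_identifier on the live cursor):

def new_columns_to_add(
    df_types: dict[str, str], table_columns: list[str], case_sensitive: bool
) -> dict[str, str]:
    """Return DataFrame columns (name -> Redshift type) missing from the target table."""
    if not case_sensitive:
        # Redshift folds identifiers to lowercase unless enable_case_sensitive_identifier is on.
        df_types = {name.lower(): col_type for name, col_type in df_types.items()}
    existing = set(table_columns)
    return {name: col_type for name, col_type in df_types.items() if name not in existing}


# The table already has "id" and "name"; the DataFrame adds "Score".
print(new_columns_to_add(
    {"id": "BIGINT", "name": "VARCHAR(256)", "Score": "DOUBLE PRECISION"},
    ["id", "name"],
    case_sensitive=False,
))
# {'score': 'DOUBLE PRECISION'} -> one ALTER TABLE ... ADD COLUMN statement per entry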

awswrangler/redshift/_write.py

Lines changed: 52 additions & 1 deletion
@@ -13,7 +13,14 @@
 from awswrangler._config import apply_configs
 
 from ._connect import _validate_connection
-from ._utils import _create_table, _make_s3_auth_string, _upsert
+from ._utils import (
+    _add_new_table_columns,
+    _create_table,
+    _does_table_exist,
+    _get_rsh_columns_types,
+    _make_s3_auth_string,
+    _upsert,
+)
 
 if TYPE_CHECKING:
     try:
@@ -102,6 +109,7 @@ def to_sql(
     chunksize: int = 200,
     commit_transaction: bool = True,
     precombine_key: str | None = None,
+    add_new_columns: bool = False,
 ) -> None:
     """Write records stored in a DataFrame into Redshift.
 
@@ -169,6 +177,8 @@ def to_sql(
         When there is a primary_key match during upsert, this column will change the upsert method,
         comparing the values of the specified column from source and target, and keeping the
         larger of the two. Will only work when mode = upsert.
+    add_new_columns
+        If True, it automatically adds the new DataFrame columns into the target table.
 
     Examples
     --------
@@ -191,6 +201,19 @@ def to_sql(
     con.autocommit = False
     try:
         with con.cursor() as cursor:
+            if add_new_columns and _does_table_exist(cursor=cursor, schema=schema, table=table):
+                redshift_columns_types = _get_rsh_columns_types(
+                    df=df,
+                    path=None,
+                    index=index,
+                    dtype=dtype,
+                    varchar_lengths_default=varchar_lengths_default,
+                    varchar_lengths=varchar_lengths,
+                )
+                _add_new_table_columns(
+                    cursor=cursor, schema=schema, table=table, redshift_columns_types=redshift_columns_types
+                )
+
             created_table, created_schema = _create_table(
                 df=df,
                 path=None,
@@ -280,6 +303,7 @@ def copy_from_files(  # noqa: PLR0913
     s3_additional_kwargs: dict[str, str] | None = None,
     precombine_key: str | None = None,
     column_names: list[str] | None = None,
+    add_new_columns: bool = False,
 ) -> None:
     """Load files from S3 to a Table on Amazon Redshift (Through COPY command).
 
@@ -396,6 +420,8 @@ def copy_from_files(  # noqa: PLR0913
         larger of the two. Will only work when mode = upsert.
     column_names
         List of column names to map source data fields to the target columns.
+    add_new_columns
+        If True, it automatically adds the new DataFrame columns into the target table.
 
     Examples
     --------
@@ -420,6 +446,27 @@ def copy_from_files(  # noqa: PLR0913
     con.autocommit = False
     try:
         with con.cursor() as cursor:
+            if add_new_columns and _does_table_exist(cursor=cursor, schema=schema, table=table):
+                redshift_columns_types = _get_rsh_columns_types(
+                    df=None,
+                    path=path,
+                    index=False,
+                    dtype=None,
+                    varchar_lengths_default=varchar_lengths_default,
+                    varchar_lengths=varchar_lengths,
+                    parquet_infer_sampling=parquet_infer_sampling,
+                    path_suffix=path_suffix,
+                    path_ignore_suffix=path_ignore_suffix,
+                    use_threads=use_threads,
+                    boto3_session=boto3_session,
+                    s3_additional_kwargs=s3_additional_kwargs,
+                    data_format=data_format,
+                    redshift_column_types=redshift_column_types,
+                    manifest=manifest,
+                )
+                _add_new_table_columns(
+                    cursor=cursor, schema=schema, table=table, redshift_columns_types=redshift_columns_types
+                )
             created_table, created_schema = _create_table(
                 df=None,
                 path=path,
@@ -521,6 +568,7 @@ def copy(  # noqa: PLR0913
     max_rows_by_file: int | None = 10_000_000,
     precombine_key: str | None = None,
     use_column_names: bool = False,
+    add_new_columns: bool = False,
 ) -> None:
     """Load Pandas DataFrame as a Table on Amazon Redshift using parquet files on S3 as stage.
 
@@ -628,6 +676,8 @@ def copy(  # noqa: PLR0913
         If set to True, will use the column names of the DataFrame for generating the INSERT SQL Query.
         E.g. If the DataFrame has two columns `col1` and `col3` and `use_column_names` is True, data will only be
         inserted into the database columns `col1` and `col3`.
+    add_new_columns
+        If True, it automatically adds the new DataFrame columns into the target table.
 
     Examples
     --------
@@ -692,6 +742,7 @@ def copy(  # noqa: PLR0913
             sql_copy_extra_params=sql_copy_extra_params,
             precombine_key=precombine_key,
             column_names=column_names,
+            add_new_columns=add_new_columns,
         )
     finally:
         if keep_files is False:
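
The same flag on the file-based path, as a hedged sketch (connection name, bucket, role ARN, and table are placeholders; types for the new columns are inferred from the staged Parquet files, as in _get_rsh_columns_types above):

import awswrangler as wr

con = wr.redshift.connect("aws-sdk-pandas-redshift")  # illustrative Glue connection name

# Parquet files already staged in S3; any columns they contain that are missing from
# public.my_table are added before the COPY statement runs.
wr.redshift.copy_from_files(
    path="s3://my-bucket/staging/",
    con=con,
    table="my_table",
    schema="public",
    data_format="parquet",
    mode="append",
    iam_role="arn:aws:iam::111111111111:role/redshift-copy",  # illustrative role ARN
    add_new_columns=True,
)

con.close()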
