Commit 119dc4e

fix(athena): Enable use of dataframe type, in athena2pyarrow type (#2953)
* Enable use of dataframe type, in athena2pyarrow type
* Ruff format check
* Fix mypy error
* Add test to verify write was successful
* Fix test_save_dataframe_with_ms_units parameters
* Add use_threads parameter
1 parent 78522fd commit 119dc4e

File tree

awswrangler/_arrow.py
awswrangler/_data_types.py
tests/unit/test_s3_parquet.py

3 files changed: +49 -5 lines changed

awswrangler/_arrow.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def _df_to_table(
     for col_name, col_type in dtype.items():
         if col_name in table.column_names:
             col_index = table.column_names.index(col_name)
-            pyarrow_dtype = athena2pyarrow(col_type)
+            pyarrow_dtype = athena2pyarrow(col_type, df.dtypes.get(col_name))
             field = pa.field(name=col_name, type=pyarrow_dtype)
             table = table.set_column(col_index, field, table.column(col_name).cast(pyarrow_dtype))
             _logger.debug("Casting column %s (%s) to %s (%s)", col_name, col_index, col_type, pyarrow_dtype)
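For context, a minimal sketch of what this change enables at the public API level. The bucket path and column name are hypothetical, and it assumes pandas 2.x, where a column can carry a non-nanosecond datetime unit:

import pandas as pd
import awswrangler as wr

# Hypothetical column: millisecond unit, holding a value outside the
# nanosecond-representable range (roughly years 1677-2262).
df = pd.DataFrame({"ts": pd.Series(["0800-01-01 00:00:00"], dtype="datetime64[ms]")})

# With this change, an explicit Athena "timestamp" cast resolves to the
# column's ms unit instead of always forcing a nanosecond cast.
wr.s3.to_parquet(df, "s3://my-bucket/my-prefix/", dataset=True, dtype={"ts": "timestamp"})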

awswrangler/_data_types.py

Lines changed: 12 additions & 4 deletions
@@ -12,7 +12,6 @@
 import numpy as np
 import pandas as pd
 import pyarrow as pa
-import pyarrow.parquet

 from awswrangler import _arrow, exceptions
 from awswrangler._distributed import engine
@@ -306,7 +305,7 @@ def _split_map(s: str) -> list[str]:
     return parts


-def athena2pyarrow(dtype: str) -> pa.DataType:  # noqa: PLR0911,PLR0912
+def athena2pyarrow(dtype: str, df_type: str | None = None) -> pa.DataType:  # noqa: PLR0911,PLR0912
     """Athena to PyArrow data types conversion."""
     dtype = dtype.strip()
     if dtype.startswith(("array", "struct", "map")):
@@ -329,7 +328,16 @@ def athena2pyarrow(dtype: str) -> pa.DataType:  # noqa: PLR0911,PLR0912
     if (dtype in ("string", "uuid")) or dtype.startswith("char") or dtype.startswith("varchar"):
         return pa.string()
     if dtype == "timestamp":
-        return pa.timestamp(unit="ns")
+        if df_type == "datetime64[ns]":
+            return pa.timestamp(unit="ns")
+        elif df_type == "datetime64[us]":
+            return pa.timestamp(unit="us")
+        elif df_type == "datetime64[ms]":
+            return pa.timestamp(unit="ms")
+        elif df_type == "datetime64[s]":
+            return pa.timestamp(unit="s")
+        else:
+            return pa.timestamp(unit="ns")
     if dtype == "date":
         return pa.date32()
     if dtype in ("binary" or "varbinary"):
@@ -701,7 +709,7 @@ def pyarrow_schema_from_pandas(
     )
     for k, v in casts.items():
         if (k not in ignore) and (k in df.columns or _is_index_name(k, df.index)):
-            columns_types[k] = athena2pyarrow(dtype=v)
+            columns_types[k] = athena2pyarrow(dtype=v, df_type=df.dtypes.get(k))
     columns_types = {k: v for k, v in columns_types.items() if v is not None}
     _logger.debug("columns_types: %s", columns_types)
     return pa.schema(fields=columns_types)
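A minimal sketch of the resulting conversion (athena2pyarrow is a private helper; calling it directly is for illustration only, assuming a version of awswrangler that includes this commit):

import pyarrow as pa
from awswrangler._data_types import athena2pyarrow

# The pandas dtype now selects the matching Arrow timestamp unit; with no
# df_type (or an unrecognized one), the previous ns default still applies.
assert athena2pyarrow("timestamp") == pa.timestamp("ns")
assert athena2pyarrow("timestamp", "datetime64[ms]") == pa.timestamp("ms")
assert athena2pyarrow("timestamp", "datetime64[s]") == pa.timestamp("s")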

tests/unit/test_s3_parquet.py

Lines changed: 36 additions & 0 deletions
@@ -1032,3 +1032,39 @@ def test_read_from_access_point(access_point_path_path: str) -> None:
     wr.s3.to_parquet(df, path)
     df_out = wr.s3.read_parquet(path)
     assert df_out.shape == (3, 3)
+
+
+@pytest.mark.parametrize("use_threads", [True, False, 2])
+def test_save_dataframe_with_ms_units(path, glue_database, glue_table, use_threads):
+    df = pd.DataFrame(
+        {
+            "c0": [
+                "2023-01-01 00:00:00.000",
+                "2023-01-02 00:00:00.000",
+                "0800-01-01 00:00:00.000",  # Out-of-bounds timestamp
+                "2977-09-21 00:12:43.000",
+            ]
+        }
+    )
+
+    wr.s3.to_parquet(
+        df,
+        path,
+        dataset=True,
+        database=glue_database,
+        table=glue_table,
+        use_threads=use_threads,
+    )
+
+    # Save exactly the same data twice. This ensures that even if the Athena table
+    # already exists, the flow of using its metadata to identify the schema of the
+    # data works correctly.
+    wr.s3.to_parquet(
+        df,
+        path,
+        dataset=True,
+        database=glue_database,
+        table=glue_table,
+        use_threads=use_threads,
+    )
+    df_out = wr.s3.read_parquet_table(table=glue_table, database=glue_database)
+    assert df_out.shape == (8, 1)
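Why the out-of-bounds rows matter: nanosecond timestamps only cover roughly the years 1677-2262, so values like 0800-01-01 or 2977-09-21 round-trip only when a coarser unit is preserved. A small standalone sketch of that constraint (assumes pandas 2.x; not part of the commit):

import pandas as pd
import pyarrow as pa

# Both values lie outside the ns-representable range.
s = pd.Series(["0800-01-01 00:00:00.000", "2977-09-21 00:12:43.000"], dtype="datetime64[ms]")

# Keeping the ms unit represents them exactly; a cast to pa.timestamp("ns")
# would overflow.
arr = pa.Array.from_pandas(s, type=pa.timestamp("ms"))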
