diff --git a/odps/_version.py b/odps/_version.py index bfcc546..52f95f7 100644 --- a/odps/_version.py +++ b/odps/_version.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -version_info = (0, 12, 3, "rc1") +version_info = (0, 12, 3) _num_index = max(idx if isinstance(v, int) else 0 for idx, v in enumerate(version_info)) __version__ = ".".join(map(str, version_info[: _num_index + 1])) + "".join( version_info[_num_index + 1 :] diff --git a/odps/df/backends/pd/types.py b/odps/df/backends/pd/types.py index b1b70ad..987a3ec 100644 --- a/odps/df/backends/pd/types.py +++ b/odps/df/backends/pd/types.py @@ -29,8 +29,9 @@ pd = None from ... import types -from ....models import TableSchema +from .... import types as odps_types from ....compat import six +from ....models import TableSchema _np_to_df_types = dict() _df_to_np_types = dict() @@ -100,7 +101,10 @@ def np_type_to_df_type(dtype, arr=None, unknown_as_string=False, name=None): raise TypeError('Unknown dtype: %s' % dtype) -def pd_to_df_schema(pd_df, unknown_as_string=False, as_type=None): +def pd_to_df_schema(pd_df, unknown_as_string=False, as_type=None, type_mapping=None): + from ..odpssql.types import odps_type_to_df_type + + type_mapping = type_mapping or {} if pd_df.index.name is not None: pd_df.reset_index(inplace=True) @@ -113,6 +117,14 @@ def pd_to_df_schema(pd_df, unknown_as_string=False, as_type=None): if as_type and names[i] in as_type: df_types.append(as_type[names[i]]) continue + if names[i] in type_mapping: + try: + odps_type = odps_types.validate_data_type(type_mapping[names[i]]) + df_type = odps_type_to_df_type(odps_type) + except Exception: + df_type = types.validate_data_type(type_mapping[names[i]]) + df_types.append(df_type) + continue df_types.append(np_type_to_df_type(dtypes.iloc[i], arr, unknown_as_string=unknown_as_string, name=names[i])) diff --git a/odps/models/tableio.py b/odps/models/tableio.py index 3c2e2e2..b465666 100644 ---
a/odps/models/tableio.py +++ b/odps/models/tableio.py @@ -841,7 +841,11 @@ def _resolve_schema( data_schema = arrow_schema_to_odps_schema(records_list.schema) elif cls._is_pd_df(records_list): data_schema = df_schema_to_odps_schema( - pd_to_df_schema(records_list, unknown_as_string=unknown_as_string) + pd_to_df_schema( + records_list, + unknown_as_string=unknown_as_string, + type_mapping=type_mapping, + ) ) elif isinstance(records_list, list) and odps_types.is_record( records_list[0] diff --git a/odps/models/tests/test_tableio.py b/odps/models/tests/test_tableio.py index 7b166c5..3d39c95 100644 --- a/odps/models/tests/test_tableio.py +++ b/odps/models/tests/test_tableio.py @@ -824,20 +824,63 @@ def test_write_pandas_with_dynamic_parts(odps, use_arrow): @pyarrow_case @pandas_case @odps2_typed_case -def test_write_pandas_with_complex_type_and_mapping(odps): +def test_write_pandas_with_arrow_complex_type(odps): if Version(pa.__version__) < Version("1.0.0"): pytest.skip("casting nested type is not supported in arrow < 1.0.0") - test_table_name = tn("pyodps_t_tmp_write_pd_complex_type") + test_table_name = tn("pyodps_t_tmp_write_arrow_complex_type") odps.delete_table(test_table_name, if_exists=True) - table = odps.create_table( - test_table_name, - "idx string, list_data array<bigint>, " - "list_struct_data array<struct<name:string, val:bigint>>, " - "map_data map<string, bigint>", - table_properties={"columnar.nested.type": "true"}, - lifecycle=1, + data = pd.DataFrame( + [ + [ + "05ac09c4", + [134, 256], + [None, {"name": "col1", "val": 134}], + ], + ["cfae9054", [5431], [{"name": "col2", "val": 2345}]], + [ + "6029501d", + [145, None, 561], + [{"name": "ddd", "val": 2341}, {"name": None, "val": None}], + ], + [ + "c653e520", + [7412, 234], + [None, {"name": "uvw", "val": None}], + ], + ["59caed0d", [295, 1674], None], + ], + columns=["idx", "list_data", "list_struct_data"], ) + arrow_data = pa.Table.from_pandas(data) + try: + table_kwargs = { + "table_properties": {"columnar.nested.type": "true"}, + } + odps.write_table(
+            test_table_name, + arrow_data, + create_table=True, + lifecycle=1, + table_kwargs=table_kwargs, + ) + table = odps.get_table(test_table_name) + pd.testing.assert_frame_equal( + data.sort_values("idx").reset_index(drop=True), + table.to_pandas().sort_values("idx").reset_index(drop=True), + ) + finally: + odps.delete_table(test_table_name, if_exists=True) + + +@pyarrow_case +@pandas_case +@odps2_typed_case +def test_write_pandas_with_complex_type_and_mapping(odps): + if Version(pa.__version__) < Version("1.0.0"): + pytest.skip("casting nested type is not supported in arrow < 1.0.0") + test_table_name = tn("pyodps_t_tmp_write_pd_complex_type") + odps.delete_table(test_table_name, if_exists=True) data = pd.DataFrame( [ @@ -870,19 +913,24 @@ def test_write_pandas_with_complex_type_and_mapping(odps): "list_struct_data": "array<struct<name:string, val:bigint>>", "map_data": "map<string, bigint>", } + table_kwargs = { + "table_properties": {"columnar.nested.type": "true"}, + } odps.write_table( test_table_name, data, type_mapping=type_mapping, create_table=True, lifecycle=1, + table_kwargs=table_kwargs, ) + table = odps.get_table(test_table_name) pd.testing.assert_frame_equal( data.sort_values("idx").reset_index(drop=True), table.to_pandas().sort_values("idx").reset_index(drop=True), ) finally: - table.drop() + odps.delete_table(test_table_name, if_exists=True) @pyarrow_case diff --git a/odps/tunnel/io/types.py b/odps/tunnel/io/types.py index b1669bb..7fe5715 100644 --- a/odps/tunnel/io/types.py +++ b/odps/tunnel/io/types.py @@ -125,7 +125,7 @@ def arrow_type_to_odps_type(arrow_type): elif isinstance(arrow_type, pa.StructType): fields = [ (arrow_type[idx].name, arrow_type_to_odps_type(arrow_type[idx].type)) - for idx in arrow_type.num_fields + for idx in range(arrow_type.num_fields) ] col_type = types.Struct(fields) else: