Skip to content

Commit 230c814

Browse files
author
maxim veksler
committed
check_round_trip refactoring to trap on FastParquet writes to S3.
1 parent 452104e commit 230c814

File tree

1 file changed

+42
-38
lines changed

1 file changed

+42
-38
lines changed

pandas/tests/io/test_parquet.py

Lines changed: 42 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -212,28 +212,37 @@ def check_error_on_write(self, df, engine, exc):
212212
with tm.ensure_clean() as path:
213213
to_parquet(df, path, engine, compression=None)
214214

215-
def check_round_trip(self, df, engine, expected=None,
216-
write_kwargs=None, read_kwargs=None,
217-
check_names=True):
215+
def do_round_trip(self, df, path, engine_impl, expected=None,
216+
write_kwargs=None, read_kwargs=None,
217+
check_names=True):
218+
218219
if write_kwargs is None:
219-
write_kwargs = {}
220+
write_kwargs = {'compression': None}
221+
220222
if read_kwargs is None:
221223
read_kwargs = {}
222-
with tm.ensure_clean() as path:
223-
df.to_parquet(path, engine, **write_kwargs)
224-
result = read_parquet(path, engine, **read_kwargs)
225224

226-
if expected is None:
227-
expected = df
228-
tm.assert_frame_equal(result, expected, check_names=check_names)
225+
df.to_parquet(path, engine_impl, **write_kwargs)
226+
actual = read_parquet(path, engine_impl, **read_kwargs)
229227

230-
# repeat
231-
to_parquet(df, path, engine, **write_kwargs)
232-
result = pd.read_parquet(path, engine, **read_kwargs)
228+
if expected is None:
229+
expected = df
233230

234-
if expected is None:
235-
expected = df
236-
tm.assert_frame_equal(result, expected, check_names=check_names)
231+
tm.assert_frame_equal(expected, actual, check_names=check_names)
232+
233+
def check_round_trip(self, df, engine, expected=None,
234+
write_kwargs=None, read_kwargs=None,
235+
check_names=True):
236+
237+
with tm.ensure_clean() as path:
238+
self.do_round_trip(df, path, engine, expected,
239+
write_kwargs=write_kwargs, read_kwargs=read_kwargs,
240+
check_names=check_names)
241+
242+
# repeat
243+
self.do_round_trip(df, path, engine, expected,
244+
write_kwargs=write_kwargs, read_kwargs=read_kwargs,
245+
check_names=check_names)
237246

238247

239248
class TestBasic(Base):
@@ -251,7 +260,7 @@ def test_columns_dtypes(self, engine):
251260

252261
# unicode
253262
df.columns = [u'foo', u'bar']
254-
self.check_round_trip(df, engine, write_kwargs={'compression': None})
263+
self.check_round_trip(df, engine)
255264

256265
def test_columns_dtypes_invalid(self, engine):
257266

@@ -292,7 +301,6 @@ def test_read_columns(self, engine):
292301

293302
expected = pd.DataFrame({'string': list('abc')})
294303
self.check_round_trip(df, engine, expected=expected,
295-
write_kwargs={'compression': None},
296304
read_kwargs={'columns': ['string']})
297305

298306
def test_write_index(self, engine):
@@ -304,7 +312,7 @@ def test_write_index(self, engine):
304312
pytest.skip("pyarrow is < 0.7.0")
305313

306314
df = pd.DataFrame({'A': [1, 2, 3]})
307-
self.check_round_trip(df, engine, write_kwargs={'compression': None})
315+
self.check_round_trip(df, engine)
308316

309317
indexes = [
310318
[2, 3, 4],
@@ -315,15 +323,12 @@ def test_write_index(self, engine):
315323
# non-default index
316324
for index in indexes:
317325
df.index = index
318-
self.check_round_trip(
319-
df, engine,
320-
write_kwargs={'compression': None},
321-
check_names=check_names)
326+
self.check_round_trip(df, engine, check_names=check_names)
322327

323328
# index with meta-data
324329
df.index = [0, 1, 2]
325330
df.index.name = 'foo'
326-
self.check_round_trip(df, engine, write_kwargs={'compression': None})
331+
self.check_round_trip(df, engine)
327332

328333
def test_write_multiindex(self, pa_ge_070):
329334
# Not supported in fastparquet as of 0.1.3 or older pyarrow version
@@ -332,7 +337,7 @@ def test_write_multiindex(self, pa_ge_070):
332337
df = pd.DataFrame({'A': [1, 2, 3]})
333338
index = pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1)])
334339
df.index = index
335-
self.check_round_trip(df, engine, write_kwargs={'compression': None})
340+
self.check_round_trip(df, engine)
336341

337342
def test_write_column_multiindex(self, engine):
338343
# column multi-index
@@ -428,13 +433,7 @@ def test_categorical_unsupported(self, pa_lt_070):
428433

429434
def test_s3_roundtrip(self, df_compat, s3_resource, pa):
430435
# GH #19134
431-
df_compat.to_parquet('s3://pandas-test/test.parquet',
432-
engine=pa, compression=None)
433-
434-
expected = df_compat
435-
actual = read_parquet('s3://pandas-test/test.parquet', engine=pa)
436-
437-
tm.assert_frame_equal(expected, actual)
436+
self.do_round_trip(df_compat, 's3://pandas-test/test.parquet', pa)
438437

439438

440439
class TestParquetFastParquet(Base):
@@ -446,7 +445,7 @@ def test_basic(self, fp, df_full):
446445
# additional supported types for fastparquet
447446
df['timedelta'] = pd.timedelta_range('1 day', periods=3)
448447

449-
self.check_round_trip(df, fp, write_kwargs={'compression': None})
448+
self.check_round_trip(df, fp)
450449

451450
@pytest.mark.skip(reason="not supported")
452451
def test_duplicate_columns(self, fp):
@@ -459,8 +458,7 @@ def test_duplicate_columns(self, fp):
459458
def test_bool_with_none(self, fp):
460459
df = pd.DataFrame({'a': [True, None, False]})
461460
expected = pd.DataFrame({'a': [1.0, np.nan, 0.0]}, dtype='float16')
462-
self.check_round_trip(df, fp, expected=expected,
463-
write_kwargs={'compression': None})
461+
self.check_round_trip(df, fp, expected=expected)
464462

465463
def test_unsupported(self, fp):
466464

@@ -476,7 +474,7 @@ def test_categorical(self, fp):
476474
if LooseVersion(fastparquet.__version__) < LooseVersion("0.1.3"):
477475
pytest.skip("CategoricalDtype not supported for older fp")
478476
df = pd.DataFrame({'a': pd.Categorical(list('abc'))})
479-
self.check_round_trip(df, fp, write_kwargs={'compression': None})
477+
self.check_round_trip(df, fp)
480478

481479
def test_datetime_tz(self, fp):
482480
# doesn't preserve tz
@@ -485,8 +483,7 @@ def test_datetime_tz(self, fp):
485483

486484
# warns on the coercion
487485
with catch_warnings(record=True):
488-
self.check_round_trip(df, fp, df.astype('datetime64[ns]'),
489-
write_kwargs={'compression': None})
486+
self.check_round_trip(df, fp, df.astype('datetime64[ns]'))
490487

491488
def test_filter_row_groups(self, fp):
492489
d = {'a': list(range(0, 3))}
@@ -497,3 +494,10 @@ def test_filter_row_groups(self, fp):
497494
result = read_parquet(path, fp, filters=[('a', '==', 0)])
498495
assert len(result) == 1
499496

497+
def test_s3_roundtrip(self, df_compat, s3_resource, fp):
498+
print(s3_resource, fp)
499+
500+
# GH #19134
501+
with pytest.raises(TypeError):
502+
self.do_round_trip(df_compat, 's3://pandas-test/test.parquet', fp)
503+

0 commit comments

Comments
 (0)