Skip to content

Commit bfa18fb

Browse files
committed
updated cast_scalar_to_array to support tuple shape for extension dtype
1 parent 0f9178e commit bfa18fb

File tree

3 files changed

+24
-5
lines changed

3 files changed

+24
-5
lines changed

pandas/core/dtypes/cast.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,10 @@ def infer_dtype_from_array(arr, pandas_dtype: bool = False) -> Tuple[DtypeObj, A
758758
if pandas_dtype and is_extension_array_dtype(arr):
759759
return arr.dtype, arr
760760

761+
dtype, _ = infer_dtype_from_scalar(arr[0], pandas_dtype=True)
762+
if is_extension_array_dtype(dtype):
763+
return dtype, arr
764+
761765
elif isinstance(arr, ABCSeries):
762766
return arr.dtype, np.asarray(arr)
763767

@@ -1510,7 +1514,10 @@ def cast_scalar_to_array(shape, value, dtype: Optional[DtypeObj] = None) -> np.n
15101514
fill_value = value
15111515

15121516
if is_extension_array_dtype(dtype):
1513-
values = dtype.construct_array_type()._from_sequence([value] * shape)
1517+
if isinstance(shape, int):
1518+
shape = (shape, 1)
1519+
value = [construct_1d_arraylike_from_scalar(value, shape[0], dtype)]
1520+
values = value * shape[1]
15141521
else:
15151522
values = np.empty(shape, dtype=dtype)
15161523
values.fill(fill_value)

pandas/core/frame.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
cast_scalar_to_array,
7777
coerce_to_dtypes,
7878
find_common_type,
79+
infer_dtype_from_array,
7980
infer_dtype_from_scalar,
8081
invalidate_string_dtypes,
8182
maybe_cast_to_datetime,
@@ -528,9 +529,15 @@ def __init__(
528529
values = cast_scalar_to_array(
529530
(len(index), len(columns)), data, dtype=dtype
530531
)
531-
mgr = init_ndarray(
532-
values, index, columns, dtype=values.dtype, copy=False
533-
)
532+
if isinstance(values, list):
533+
# Case 1: values is a list of extension arrays
534+
dtype, _ = infer_dtype_from_array(values[0], pandas_dtype=True)
535+
mgr = arrays_to_mgr(values, columns, index, columns, dtype=dtype)
536+
else:
537+
# Case 2: values is a numpy array
538+
mgr = init_ndarray(
539+
values, index, columns, dtype=values.dtype, copy=False
540+
)
534541
else:
535542
raise ValueError("DataFrame constructor not properly called!")
536543

@@ -3731,6 +3738,11 @@ def reindexer(value):
37313738

37323739
# upcast
37333740
value = cast_scalar_to_array(len(self.index), value)
3741+
3742+
# if extension dtype, value will be a list of length 1
3743+
if isinstance(value, list):
3744+
value = value[0]
3745+
37343746
value = maybe_cast_to_datetime(value, infer_dtype)
37353747

37363748
# return internal types directly

pandas/tests/dtypes/cast/test_infer_dtype.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,4 +214,4 @@ def test_cast_scalar_to_extension_array(obj, dtype):
214214
exp = dtype.construct_array_type()._from_sequence([obj] * shape)
215215

216216
arr = cast_scalar_to_array(shape, obj, dtype=dtype)
217-
tm.assert_extension_array_equal(arr, exp)
217+
tm.assert_extension_array_equal(arr[0], exp)

0 commit comments

Comments
 (0)