Skip to content

Commit 508cdea

Browse files
committed
address some comments
1 parent 8e598fe commit 508cdea

File tree

1 file changed

+44
-44
lines changed

1 file changed

+44
-44
lines changed

db_dtypes/json.py

Lines changed: 44 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -19,13 +19,10 @@
1919

2020
import numpy as np
2121
import pandas as pd
22-
from pandas.core.arrays.arrow.array import ArrowExtensionArray
23-
from pandas.core.arrays.masked import BaseMaskedArray
24-
from pandas.core.dtypes.common import is_dict_like, is_integer, is_list_like, is_scalar
25-
from pandas.core.dtypes.dtypes import ExtensionDtype
26-
from pandas.core.indexers import check_array_indexer, unpack_tuple_and_ellipses
22+
import pandas.arrays as arrays
23+
import pandas.core.dtypes.common as common
24+
import pandas.core.indexers as indexers
2725
import pyarrow as pa
28-
import pyarrow.compute as pc
2926

3027

3128
@pd.api.extensions.register_extension_dtype
@@ -63,78 +60,81 @@ def __from_arrow__(array: typing.Union[pa.Array, pa.ChunkedArray]) -> JSONArray:
6360
return JSONArray(array)
6461

6562

66-
class JSONArray(ArrowExtensionArray):
63+
class JSONArray(arrays.ArrowExtensionArray):
6764
"""Extension array that handles BigQuery JSON data, leveraging a string-based
6865
pyarrow array for storage. It enables seamless conversion to JSON objects when
6966
accessing individual elements."""
7067

7168
_dtype = JSONDtype()
7269

7370
def __init__(self, values, dtype=None, copy=False) -> None:
74-
if isinstance(values, (pa.Array, pa.ChunkedArray)) and pa.types.is_string(
75-
values.type
76-
):
77-
values = pc.cast(values, pa.large_string())
78-
79-
super().__init__(values)
8071
self._dtype = JSONDtype()
81-
82-
if not pa.types.is_large_string(self._pa_array.type) and not (
83-
pa.types.is_dictionary(self._pa_array.type)
84-
and pa.types.is_large_string(self._pa_array.type.value_type)
85-
):
86-
raise ValueError(
87-
"ArrowStringArray requires a PyArrow (chunked) array of "
88-
"large_string type"
89-
)
72+
if isinstance(values, pa.Array):
73+
self._pa_array = pa.chunked_array([values])
74+
elif isinstance(values, pa.ChunkedArray):
75+
self._pa_array = values
76+
else:
77+
raise ValueError(f"Unsupported type '{type(values)}' for JSONArray")
9078

9179
@classmethod
9280
def _box_pa(
9381
cls, value, pa_type: pa.DataType | None = None
9482
) -> pa.Array | pa.ChunkedArray | pa.Scalar:
9583
"""Box value into a pyarrow Array, ChunkedArray or Scalar."""
9684
if isinstance(value, pa.Scalar) or not (
97-
is_list_like(value) and not is_dict_like(value)
85+
common.is_list_like(value) and not common.is_dict_like(value)
9886
):
9987
return cls._box_pa_scalar(value, pa_type)
10088
return cls._box_pa_array(value, pa_type)
10189

10290
@classmethod
10391
def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar:
10492
"""Box value into a pyarrow Scalar."""
105-
value = JSONArray._seralizate_json(value)
106-
pa_scalar = super()._box_pa_scalar(value, pa_type)
107-
if pa.types.is_string(pa_scalar.type) and pa_type is None:
108-
pa_scalar = pc.cast(pa_scalar, pa.large_string())
93+
if isinstance(value, pa.Scalar):
94+
pa_scalar = value
95+
if pd.isna(value):
96+
pa_scalar = pa.scalar(None, type=pa_type)
97+
else:
98+
value = JSONArray._serialize_json(value)
99+
pa_scalar = pa.scalar(value, type=pa_type, from_pandas=True)
100+
101+
if pa_type is not None and pa_scalar.type != pa_type:
102+
pa_scalar = pa_scalar.cast(pa_type)
109103
return pa_scalar
110104

111105
@classmethod
112106
def _box_pa_array(
113107
cls, value, pa_type: pa.DataType | None = None, copy: bool = False
114108
) -> pa.Array | pa.ChunkedArray:
115109
"""Box value into a pyarrow Array or ChunkedArray."""
116-
if (
117-
not isinstance(value, cls)
118-
and not isinstance(value, (pa.Array, pa.ChunkedArray))
119-
and not isinstance(value, BaseMaskedArray)
120-
):
121-
value = [JSONArray._seralizate_json(x) for x in value]
122-
pa_array = super()._box_pa_array(value, pa_type)
123-
if pa.types.is_string(pa_array.type) and pa_type is None:
124-
pa_array = pc.cast(pa_array, pa.large_string())
110+
if isinstance(value, cls):
111+
pa_array = value._pa_array
112+
elif isinstance(value, (pa.Array, pa.ChunkedArray)):
113+
pa_array = value
114+
else:
115+
try:
116+
value = [JSONArray._serialize_json(x) for x in value]
117+
pa_array = pa.array(value, type=pa_type, from_pandas=True)
118+
except (pa.ArrowInvalid, pa.ArrowTypeError):
119+
# GH50430: let pyarrow infer type, then cast
120+
pa_array = pa.array(value, from_pandas=True)
121+
122+
if pa_type is not None and pa_array.type != pa_type:
123+
pa_array = pa_array.cast(pa_type)
124+
125125
return pa_array
126126

127127
@classmethod
128128
def _from_sequence(cls, scalars, *, dtype=None, copy=False):
129129
"""Construct a new ExtensionArray from a sequence of scalars."""
130130
result = []
131131
for scalar in scalars:
132-
result.append(JSONArray._seralizate_json(scalar))
132+
result.append(JSONArray._serialize_json(scalar))
133133
return cls(pa.array(result, type=pa.large_string(), from_pandas=True))
134134

135135
@classmethod
136136
def _from_sequence_of_strings(
137-
cls, strings, *, dtype: ExtensionDtype, copy: bool = False
137+
cls, strings, *, dtype, copy: bool = False
138138
) -> JSONArray:
139139
"""Construct a new ExtensionArray from a sequence of strings."""
140140
return cls._from_sequence(strings, dtype=dtype, copy=copy)
@@ -152,7 +152,7 @@ def _from_factorized(cls, values, original):
152152
return cls._from_sequence(values, dtype=original.dtype)
153153

154154
@staticmethod
155-
def _seralizate_json(value):
155+
def _serialize_json(value):
156156
"""A static method that converts a JSON value into a string representation."""
157157
if isinstance(value, str) or pd.isna(value):
158158
return value
@@ -176,19 +176,19 @@ def dtype(self) -> JSONDtype:
176176

177177
def __contains__(self, key) -> bool:
178178
"""Return for `item in self`."""
179-
return super().__contains__(JSONArray._seralizate_json(key))
179+
return super().__contains__(JSONArray._serialize_json(key))
180180

181181
def insert(self, loc: int, item) -> JSONArray:
182182
"""
183183
Make new ExtensionArray inserting new item at location. Follows Python
184184
list.append semantics for negative values.
185185
"""
186-
val = JSONArray._seralizate_json(item)
186+
val = JSONArray._serialize_json(item)
187187
return super().insert(loc, val)
188188

189189
def __getitem__(self, item):
190190
"""Select a subset of self."""
191-
item = check_array_indexer(self, item)
191+
item = indexers.check_array_indexer(self, item)
192192

193193
if isinstance(item, np.ndarray):
194194
if not len(item):
@@ -203,9 +203,9 @@ def __getitem__(self, item):
203203
"boolean arrays are valid indices."
204204
)
205205
elif isinstance(item, tuple):
206-
item = unpack_tuple_and_ellipses(item)
206+
item = indexers.unpack_tuple_and_ellipses(item)
207207

208-
if is_scalar(item) and not is_integer(item):
208+
if common.is_scalar(item) and not common.is_integer(item):
209209
# e.g. "foo" or 2.5
210210
# exception message copied from numpy
211211
raise IndexError(

0 commit comments

Comments
 (0)