@@ -282,6 +282,77 @@ def test_array_interface(self):
282
282
)
283
283
tm .assert_numpy_array_equal (result , expected )
284
284
285
+ @pytest .mark .parametrize ("index" , [True , False ])
286
+ def test_searchsorted_different_tz (self , index ):
287
+ data = np .arange (10 , dtype = "i8" ) * 24 * 3600 * 10 ** 9
288
+ arr = DatetimeArray (data , freq = "D" ).tz_localize ("Asia/Tokyo" )
289
+ if index :
290
+ arr = pd .Index (arr )
291
+
292
+ expected = arr .searchsorted (arr [2 ])
293
+ result = arr .searchsorted (arr [2 ].tz_convert ("UTC" ))
294
+ assert result == expected
295
+
296
+ expected = arr .searchsorted (arr [2 :6 ])
297
+ result = arr .searchsorted (arr [2 :6 ].tz_convert ("UTC" ))
298
+ tm .assert_equal (result , expected )
299
+
300
+ @pytest .mark .parametrize ("index" , [True , False ])
301
+ def test_searchsorted_tzawareness_compat (self , index ):
302
+ data = np .arange (10 , dtype = "i8" ) * 24 * 3600 * 10 ** 9
303
+ arr = DatetimeArray (data , freq = "D" )
304
+ if index :
305
+ arr = pd .Index (arr )
306
+
307
+ mismatch = arr .tz_localize ("Asia/Tokyo" )
308
+
309
+ msg = "Cannot compare tz-naive and tz-aware datetime-like objects"
310
+ with pytest .raises (TypeError , match = msg ):
311
+ arr .searchsorted (mismatch [0 ])
312
+ with pytest .raises (TypeError , match = msg ):
313
+ arr .searchsorted (mismatch )
314
+
315
+ with pytest .raises (TypeError , match = msg ):
316
+ mismatch .searchsorted (arr [0 ])
317
+ with pytest .raises (TypeError , match = msg ):
318
+ mismatch .searchsorted (arr )
319
+
320
+ @pytest .mark .parametrize (
321
+ "other" ,
322
+ [
323
+ 1 ,
324
+ np .int64 (1 ),
325
+ 1.0 ,
326
+ np .timedelta64 ("NaT" ),
327
+ pd .Timedelta (days = 2 ),
328
+ "invalid" ,
329
+ np .arange (10 , dtype = "i8" ) * 24 * 3600 * 10 ** 9 ,
330
+ np .arange (10 ).view ("timedelta64[ns]" ) * 24 * 3600 * 10 ** 9 ,
331
+ pd .Timestamp .now ().to_period ("D" ),
332
+ ],
333
+ )
334
+ @pytest .mark .parametrize (
335
+ "index" ,
336
+ [
337
+ True ,
338
+ pytest .param (
339
+ False ,
340
+ marks = pytest .mark .xfail (
341
+ reason = "Raises ValueError instead of TypeError" , raises = ValueError
342
+ ),
343
+ ),
344
+ ],
345
+ )
346
+ def test_searchsorted_invalid_types (self , other , index ):
347
+ data = np .arange (10 , dtype = "i8" ) * 24 * 3600 * 10 ** 9
348
+ arr = DatetimeArray (data , freq = "D" )
349
+ if index :
350
+ arr = pd .Index (arr )
351
+
352
+ msg = "searchsorted requires compatible dtype or scalar"
353
+ with pytest .raises (TypeError , match = msg ):
354
+ arr .searchsorted (other )
355
+
285
356
286
357
class TestSequenceToDT64NS :
287
358
def test_tz_dtype_mismatch_raises (self ):
0 commit comments