diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index c4235ba1..8b7ef18e 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -109,7 +109,7 @@ class OnewayPromotableDtypes(NamedTuple): @composite def oneway_promotable_dtypes( draw, dtypes: Sequence[DataType] -) -> SearchStrategy[OnewayPromotableDtypes]: +) -> OnewayPromotableDtypes: """Return a strategy for input dtypes that promote to result dtypes.""" d1, d2 = draw(mutually_promotable_dtypes(dtypes=dtypes)) result_dtype = dh.result_type(d1, d2) @@ -127,7 +127,7 @@ class OnewayBroadcastableShapes(NamedTuple): @composite -def oneway_broadcastable_shapes(draw) -> SearchStrategy[OnewayBroadcastableShapes]: +def oneway_broadcastable_shapes(draw) -> OnewayBroadcastableShapes: """Return a strategy for input shapes that broadcast to result shapes.""" result_shape = draw(shapes(min_side=1)) input_shape = draw( diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 1c2a24f7..94b6b0ec 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -377,7 +377,7 @@ def test_eye(n_rows, n_cols, kw): @st.composite -def full_fill_values(draw) -> st.SearchStrategy[Union[bool, int, float, complex]]: +def full_fill_values(draw) -> Union[bool, int, float, complex]: kw = draw( st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_kw") )