From 6a6f4cfb5c697501cd0eeb281f83237147965175 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 17 Nov 2023 16:09:20 +0000 Subject: [PATCH] Remove `st.SearchStrategy` for hinting `@st.composite` strategies composite decorated functions should be type hinted by the type of individual draws --- array_api_tests/hypothesis_helpers.py | 4 ++-- array_api_tests/test_creation_functions.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index c4235ba1..8b7ef18e 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -109,7 +109,7 @@ class OnewayPromotableDtypes(NamedTuple): @composite def oneway_promotable_dtypes( draw, dtypes: Sequence[DataType] -) -> SearchStrategy[OnewayPromotableDtypes]: +) -> OnewayPromotableDtypes: """Return a strategy for input dtypes that promote to result dtypes.""" d1, d2 = draw(mutually_promotable_dtypes(dtypes=dtypes)) result_dtype = dh.result_type(d1, d2) @@ -127,7 +127,7 @@ class OnewayBroadcastableShapes(NamedTuple): @composite -def oneway_broadcastable_shapes(draw) -> SearchStrategy[OnewayBroadcastableShapes]: +def oneway_broadcastable_shapes(draw) -> OnewayBroadcastableShapes: """Return a strategy for input shapes that broadcast to result shapes.""" result_shape = draw(shapes(min_side=1)) input_shape = draw( diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 1c2a24f7..94b6b0ec 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -377,7 +377,7 @@ def test_eye(n_rows, n_cols, kw): @st.composite -def full_fill_values(draw) -> st.SearchStrategy[Union[bool, int, float, complex]]: +def full_fill_values(draw) -> Union[bool, int, float, complex]: kw = draw( st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_kw") )