diff --git a/tests/core/method-class/test_method.py b/tests/core/method-class/test_method.py index 0d82a6b602..d9671ab24c 100644 --- a/tests/core/method-class/test_method.py +++ b/tests/core/method-class/test_method.py @@ -17,7 +17,6 @@ ) from web3.method import ( Method, - _apply_request_formatters, default_root_munger, ) from web3.module import ( @@ -61,11 +60,7 @@ def test_get_formatters_default_formatter_for_falsy_config(): default_result_formatters = method.result_formatters( method.method_selector_fn(), "some module" ) - assert _apply_request_formatters(["a", "b", "c"], default_request_formatters) == ( - "a", - "b", - "c", - ) + assert default_request_formatters(["a", "b", "c"]) == ("a", "b", "c") assert apply_result_formatters(default_result_formatters, ["a", "b", "c"]) == [ "a", "b", @@ -81,7 +76,7 @@ def test_get_formatters_non_falsy_config_retrieval(): method_name = method.method_selector_fn() first_formatter = (method.request_formatters(method_name).first,) all_other_formatters = method.request_formatters(method_name).funcs - assert len(first_formatter + all_other_formatters) == 2 + assert len(first_formatter + all_other_formatters) == 3 assert (method.request_formatters("eth_getBalance").first,) == first_formatter diff --git a/web3/_utils/method_formatters.py b/web3/_utils/method_formatters.py index d66852ed9d..7b475e3238 100644 --- a/web3/_utils/method_formatters.py +++ b/web3/_utils/method_formatters.py @@ -172,12 +172,9 @@ def type_aware_apply_formatters_to_dict( if isinstance(value, BaseModel): value = value.model_dump(by_alias=True) - formatted_dict: Dict[str, Any] = apply_formatters_to_dict(formatters, dict(value)) - return ( - AttributeDict.recursive(formatted_dict) - if is_attrdict(value) - else formatted_dict - ) + formatted: Dict[str, Any] + formatted = apply_formatters_to_dict(formatters, value) # type: ignore [arg-type] + return AttributeDict.recursive(formatted) if is_attrdict(value) else formatted def type_aware_apply_formatters_to_dict_keys_and_values( @@ -1115,7 +1112,7 @@ def get_request_formatters(method_name: RPCEndpoint) -> Callable[[RPCResponse], PYTHONIC_REQUEST_FORMATTERS, ) formatters = combine_formatters(request_formatter_maps, method_name) - return compose(*formatters) + return compose(tuple, *formatters) def raise_block_not_found(params: Tuple[BlockIdentifier, bool]) -> NoReturn: diff --git a/web3/manager.py b/web3/manager.py index 5f52edc63b..4e95293094 100644 --- a/web3/manager.py +++ b/web3/manager.py @@ -6,7 +6,6 @@ AsyncGenerator, Callable, Coroutine, - Dict, List, Optional, Sequence, @@ -48,6 +47,7 @@ ) from web3.method import ( Method, + ResponseFormatters, ) from web3.middleware import ( AttributeDictMiddleware, @@ -360,9 +360,7 @@ async def socket_request( self, method: RPCEndpoint, params: Any, - response_formatters: Optional[ - Tuple[Dict[str, Callable[..., Any]], Callable[..., Any], Callable[..., Any]] - ] = None, + response_formatters: Optional[ResponseFormatters] = None, ) -> RPCResponse: provider = cast(PersistentConnectionProvider, self._provider) self.logger.debug( diff --git a/web3/method.py b/web3/method.py index b3794e46bd..3e72bf25d6 100644 --- a/web3/method.py +++ b/web3/method.py @@ -3,24 +3,15 @@ TYPE_CHECKING, Any, Callable, - Dict, Generic, List, Optional, Sequence, Tuple, Type, - Union, ) import warnings -from eth_utils.curried import ( - to_tuple, -) -from eth_utils.toolz import ( - pipe, -) - from web3._utils.batching import ( RPC_METHODS_UNSUPPORTED_DURING_BATCH, ) @@ -40,6 +31,8 @@ Web3ValueError, ) from web3.types import ( + Formatter, + ResponseFormatter, RPCEndpoint, TFunc, TReturn, @@ -54,16 +47,11 @@ Munger = Callable[..., Any] - - -@to_tuple -def _apply_request_formatters( - params: Any, request_formatters: Dict[RPCEndpoint, Callable[..., TReturn]] -) -> Tuple[Any, ...]: - if request_formatters: - formatted_params = pipe(params, request_formatters) - return formatted_params - return params +RequestArgs = Tuple[RPCEndpoint, Sequence[Any]] +ResponseFormatters = Tuple[ + ResponseFormatter[Any], ResponseFormatter[Any], ResponseFormatter[Any] +] +RequestAndFormatters = Tuple[RequestArgs, ResponseFormatters] def _set_mungers( @@ -136,9 +124,13 @@ def __init__( self, json_rpc_method: Optional[RPCEndpoint] = None, mungers: Optional[Sequence[Munger]] = None, - request_formatters: Optional[Callable[..., TReturn]] = None, - result_formatters: Optional[Callable[..., TReturn]] = None, - null_result_formatters: Optional[Callable[..., TReturn]] = None, + request_formatters: Optional[Callable[[RPCEndpoint], Formatter[TReturn]]] = None, + result_formatters: Optional[ + Callable[[RPCEndpoint, "Module"], ResponseFormatter[TReturn]] + ] = None, + null_result_formatters: Optional[ + Callable[[RPCEndpoint], ResponseFormatter[TReturn]] + ] = None, method_choice_depends_on_args: Optional[Callable[..., RPCEndpoint]] = None, is_property: bool = False, ): @@ -166,13 +158,13 @@ def __get__( ) provider = module.w3.provider - if hasattr(provider, "_is_batching") and provider._is_batching: + if getattr(provider, "_is_batching", False): if self.json_rpc_method in RPC_METHODS_UNSUPPORTED_DURING_BATCH: raise MethodNotSupported( f"Method `{self.json_rpc_method}` is not supported within a batch " "request." ) - return module.retrieve_request_information(self) + return module.retrieve_request_information(self) # type: ignore [return-value] else: return module.retrieve_caller_fn(self) @@ -203,14 +195,7 @@ def input_munger(self, module: "Module", args: Any, kwargs: Any) -> List[Any]: def process_params( self, module: "Module", *args: Any, **kwargs: Any - ) -> Tuple[ - Tuple[Union[RPCEndpoint, Callable[..., RPCEndpoint]], Tuple[RPCEndpoint, ...]], - Tuple[ - Union[TReturn, Dict[str, Callable[..., Any]]], - Callable[..., Any], - Union[TReturn, Callable[..., Any]], - ], - ]: + ) -> RequestAndFormatters: params = self.input_munger(module, args, kwargs) if self.method_choice_depends_on_args: @@ -233,10 +218,12 @@ def process_params( get_error_formatters(method), self.null_result_formatters(method), ) - request = ( - method, - _apply_request_formatters(params, self.request_formatters(method)), - ) + + if request_formatters := self.request_formatters(method): + params = request_formatters(params) + + request = method, params + return request, response_formatters diff --git a/web3/module.py b/web3/module.py index 0f9d85d7bd..b259d7402b 100644 --- a/web3/module.py +++ b/web3/module.py @@ -5,11 +5,9 @@ Coroutine, Dict, Optional, - Sequence, - Tuple, TypeVar, Union, - cast, + overload, ) from eth_abi.codec import ( @@ -27,14 +25,15 @@ ) from web3.method import ( Method, + RequestAndFormatters, ) from web3.providers.persistent import ( PersistentConnectionProvider, ) from web3.types import ( FormattedEthSubscriptionResponse, - RPCEndpoint, RPCResponse, + TFunc, ) if TYPE_CHECKING: @@ -58,34 +57,50 @@ def apply_result_formatters( TReturn = TypeVar("TReturn") +@overload +def retrieve_request_information_for_batching( + w3: "AsyncWeb3", + module: "Module", + method: Method[Callable[..., Any]], +) -> Callable[..., Coroutine[Any, Any, RequestAndFormatters]]: + ... + + +@overload +def retrieve_request_information_for_batching( + w3: "Web3", + module: "Module", + method: Method[Callable[..., Any]], +) -> Callable[..., RequestAndFormatters]: + ... + + @curry def retrieve_request_information_for_batching( w3: Union["AsyncWeb3", "Web3"], module: "Module", method: Method[Callable[..., Any]], ) -> Union[ - Callable[..., Tuple[Tuple[RPCEndpoint, Any], Sequence[Any]]], - Callable[..., Coroutine[Any, Any, Tuple[Tuple[RPCEndpoint, Any], Sequence[Any]]]], + Callable[..., RequestAndFormatters], + Callable[..., Coroutine[Any, Any, RequestAndFormatters]], ]: async def async_inner( *args: Any, **kwargs: Any - ) -> Tuple[Tuple[RPCEndpoint, Any], Sequence[Any]]: + ) -> RequestAndFormatters: (method_str, params), response_formatters = method.process_params( module, *args, **kwargs ) if isinstance(w3.provider, PersistentConnectionProvider): w3.provider._request_processor.cache_request_information( - None, cast(RPCEndpoint, method_str), params, response_formatters + None, method_str, params, response_formatters ) - return (cast(RPCEndpoint, method_str), params), response_formatters + return (method_str, params), response_formatters - def inner( - *args: Any, **kwargs: Any - ) -> Tuple[Tuple[RPCEndpoint, Any], Sequence[Any]]: + def inner(*args: Any, **kwargs: Any) -> RequestAndFormatters: (method_str, params), response_formatters = method.process_params( module, *args, **kwargs ) - return (cast(RPCEndpoint, method_str), params), response_formatters + return (method_str, params), response_formatters return async_inner if module.is_async else inner @@ -142,7 +157,7 @@ async def caller( if isinstance(async_w3.provider, PersistentConnectionProvider): return await async_w3.manager.socket_request( - cast(RPCEndpoint, method_str), + method_str, params, response_formatters=response_formatters, ) @@ -167,13 +182,23 @@ async def caller( class Module: is_async = False + retrieve_request_information: Callable[ + [Method[TFunc]], + Union[ + Callable[..., RequestAndFormatters], + Callable[..., Coroutine[Any, Any, RequestAndFormatters]], + ], + ] + def __init__(self, w3: Union["AsyncWeb3", "Web3"]) -> None: if self.is_async: self.retrieve_caller_fn = retrieve_async_method_call_fn(w3, self) else: self.retrieve_caller_fn = retrieve_blocking_method_call_fn(w3, self) - self.retrieve_request_information = retrieve_request_information_for_batching( - w3, self + self.retrieve_request_information = ( + retrieve_request_information_for_batching( # type: ignore [call-overload] + w3, self + ) ) self.w3 = w3 diff --git a/web3/types.py b/web3/types.py index 8a1f67ad44..fb8a17ec07 100644 --- a/web3/types.py +++ b/web3/types.py @@ -69,7 +69,10 @@ Timestamp = NewType("Timestamp", int) Wei = NewType("Wei", int) Gwei = NewType("Gwei", int) -Formatters = Dict[RPCEndpoint, Callable[..., Any]] +Formatter = Callable[..., TReturn] +Formatters = Dict[RPCEndpoint, Formatter[Any]] +ResponseFormatter = Callable[["RPCResponse"], TReturn] +ResponseFormatters = Dict[RPCEndpoint, ResponseFormatter[Any]] class AccessListEntry(TypedDict): @@ -335,9 +338,9 @@ class CreateAccessListResponse(TypedDict): class FormattersDict(TypedDict, total=False): - error_formatters: Optional[Formatters] - request_formatters: Optional[Formatters] - result_formatters: Optional[Formatters] + error_formatters: Optional[ResponseFormatters] + request_formatters: Optional[ResponseFormatters] + result_formatters: Optional[ResponseFormatters] class FilterParams(TypedDict, total=False):