Skip to content

chore: refactor out _apply_request_formatters #3670

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions tests/core/method-class/test_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
)
from web3.method import (
Method,
_apply_request_formatters,
default_root_munger,
)
from web3.module import (
Expand Down Expand Up @@ -61,11 +60,7 @@ def test_get_formatters_default_formatter_for_falsy_config():
default_result_formatters = method.result_formatters(
method.method_selector_fn(), "some module"
)
assert _apply_request_formatters(["a", "b", "c"], default_request_formatters) == (
"a",
"b",
"c",
)
assert default_request_formatters(["a", "b", "c"]) == ("a", "b", "c")
assert apply_result_formatters(default_result_formatters, ["a", "b", "c"]) == [
"a",
"b",
Expand All @@ -81,7 +76,7 @@ def test_get_formatters_non_falsy_config_retrieval():
method_name = method.method_selector_fn()
first_formatter = (method.request_formatters(method_name).first,)
all_other_formatters = method.request_formatters(method_name).funcs
assert len(first_formatter + all_other_formatters) == 2
assert len(first_formatter + all_other_formatters) == 3
assert (method.request_formatters("eth_getBalance").first,) == first_formatter


Expand Down
11 changes: 4 additions & 7 deletions web3/_utils/method_formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,9 @@ def type_aware_apply_formatters_to_dict(
if isinstance(value, BaseModel):
value = value.model_dump(by_alias=True)

formatted_dict: Dict[str, Any] = apply_formatters_to_dict(formatters, dict(value))
return (
AttributeDict.recursive(formatted_dict)
if is_attrdict(value)
else formatted_dict
)
formatted: Dict[str, Any]
formatted = apply_formatters_to_dict(formatters, value) # type: ignore [arg-type]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

had to move this due to line len err

return AttributeDict.recursive(formatted) if is_attrdict(value) else formatted


def type_aware_apply_formatters_to_dict_keys_and_values(
Expand Down Expand Up @@ -1115,7 +1112,7 @@ def get_request_formatters(method_name: RPCEndpoint) -> Callable[[RPCResponse],
PYTHONIC_REQUEST_FORMATTERS,
)
formatters = combine_formatters(request_formatter_maps, method_name)
return compose(*formatters)
return compose(tuple, *formatters)


def raise_block_not_found(params: Tuple[BlockIdentifier, bool]) -> NoReturn:
Expand Down
6 changes: 2 additions & 4 deletions web3/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
AsyncGenerator,
Callable,
Coroutine,
Dict,
List,
Optional,
Sequence,
Expand Down Expand Up @@ -48,6 +47,7 @@
)
from web3.method import (
Method,
ResponseFormatters,
)
from web3.middleware import (
AttributeDictMiddleware,
Expand Down Expand Up @@ -360,9 +360,7 @@ async def socket_request(
self,
method: RPCEndpoint,
params: Any,
response_formatters: Optional[
Tuple[Dict[str, Callable[..., Any]], Callable[..., Any], Callable[..., Any]]
] = None,
response_formatters: Optional[ResponseFormatters] = None,
) -> RPCResponse:
provider = cast(PersistentConnectionProvider, self._provider)
self.logger.debug(
Expand Down
59 changes: 23 additions & 36 deletions web3/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,15 @@
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
import warnings

from eth_utils.curried import (
to_tuple,
)
from eth_utils.toolz import (
pipe,
)

from web3._utils.batching import (
RPC_METHODS_UNSUPPORTED_DURING_BATCH,
)
Expand All @@ -40,6 +31,8 @@
Web3ValueError,
)
from web3.types import (
Formatter,
ResponseFormatter,
RPCEndpoint,
TFunc,
TReturn,
Expand All @@ -54,16 +47,11 @@


Munger = Callable[..., Any]


@to_tuple
def _apply_request_formatters(
params: Any, request_formatters: Dict[RPCEndpoint, Callable[..., TReturn]]
) -> Tuple[Any, ...]:
if request_formatters:
formatted_params = pipe(params, request_formatters)
return formatted_params
return params
RequestArgs = Tuple[RPCEndpoint, Sequence[Any]]
ResponseFormatters = Tuple[
ResponseFormatter[Any], ResponseFormatter[Any], ResponseFormatter[Any]
]
RequestAndFormatters = Tuple[RequestArgs, ResponseFormatters]


def _set_mungers(
Expand Down Expand Up @@ -136,9 +124,13 @@ def __init__(
self,
json_rpc_method: Optional[RPCEndpoint] = None,
mungers: Optional[Sequence[Munger]] = None,
request_formatters: Optional[Callable[..., TReturn]] = None,
result_formatters: Optional[Callable[..., TReturn]] = None,
null_result_formatters: Optional[Callable[..., TReturn]] = None,
request_formatters: Optional[Callable[[RPCEndpoint], Formatter[TReturn]]] = None,
result_formatters: Optional[
Callable[[RPCEndpoint, "Module"], ResponseFormatter[TReturn]]
] = None,
null_result_formatters: Optional[
Callable[[RPCEndpoint], ResponseFormatter[TReturn]]
] = None,
method_choice_depends_on_args: Optional[Callable[..., RPCEndpoint]] = None,
is_property: bool = False,
):
Expand Down Expand Up @@ -166,13 +158,13 @@ def __get__(
)

provider = module.w3.provider
if hasattr(provider, "_is_batching") and provider._is_batching:
if getattr(provider, "_is_batching", False):
if self.json_rpc_method in RPC_METHODS_UNSUPPORTED_DURING_BATCH:
raise MethodNotSupported(
f"Method `{self.json_rpc_method}` is not supported within a batch "
"request."
)
return module.retrieve_request_information(self)
return module.retrieve_request_information(self) # type: ignore [return-value]
Copy link
Contributor Author

@BobTheBuidler BobTheBuidler May 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really sure how to deal with the line len err here

else:
return module.retrieve_caller_fn(self)

Expand Down Expand Up @@ -203,14 +195,7 @@ def input_munger(self, module: "Module", args: Any, kwargs: Any) -> List[Any]:

def process_params(
self, module: "Module", *args: Any, **kwargs: Any
) -> Tuple[
Tuple[Union[RPCEndpoint, Callable[..., RPCEndpoint]], Tuple[RPCEndpoint, ...]],
Tuple[
Union[TReturn, Dict[str, Callable[..., Any]]],
Callable[..., Any],
Union[TReturn, Callable[..., Any]],
],
]:
) -> RequestAndFormatters:
params = self.input_munger(module, args, kwargs)

if self.method_choice_depends_on_args:
Expand All @@ -233,10 +218,12 @@ def process_params(
get_error_formatters(method),
self.null_result_formatters(method),
)
request = (
method,
_apply_request_formatters(params, self.request_formatters(method)),
)

if request_formatters := self.request_formatters(method):
params = request_formatters(params)

request = method, params

return request, response_formatters


Expand Down
57 changes: 41 additions & 16 deletions web3/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
Coroutine,
Dict,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
cast,
overload,
)

from eth_abi.codec import (
Expand All @@ -27,14 +25,15 @@
)
from web3.method import (
Method,
RequestAndFormatters,
)
from web3.providers.persistent import (
PersistentConnectionProvider,
)
from web3.types import (
FormattedEthSubscriptionResponse,
RPCEndpoint,
RPCResponse,
TFunc,
)

if TYPE_CHECKING:
Expand All @@ -58,34 +57,50 @@ def apply_result_formatters(
TReturn = TypeVar("TReturn")


@overload
def retrieve_request_information_for_batching(
w3: "AsyncWeb3",
module: "Module",
method: Method[Callable[..., Any]],
) -> Callable[..., Coroutine[Any, Any, RequestAndFormatters]]:
...


@overload
def retrieve_request_information_for_batching(
w3: "Web3",
module: "Module",
method: Method[Callable[..., Any]],
) -> Callable[..., RequestAndFormatters]:
...


@curry
def retrieve_request_information_for_batching(
w3: Union["AsyncWeb3", "Web3"],
module: "Module",
method: Method[Callable[..., Any]],
) -> Union[
Callable[..., Tuple[Tuple[RPCEndpoint, Any], Sequence[Any]]],
Callable[..., Coroutine[Any, Any, Tuple[Tuple[RPCEndpoint, Any], Sequence[Any]]]],
Callable[..., RequestAndFormatters],
Callable[..., Coroutine[Any, Any, RequestAndFormatters]],
]:
async def async_inner(
*args: Any, **kwargs: Any
) -> Tuple[Tuple[RPCEndpoint, Any], Sequence[Any]]:
) -> RequestAndFormatters:
(method_str, params), response_formatters = method.process_params(
module, *args, **kwargs
)
if isinstance(w3.provider, PersistentConnectionProvider):
w3.provider._request_processor.cache_request_information(
None, cast(RPCEndpoint, method_str), params, response_formatters
None, method_str, params, response_formatters
)
return (cast(RPCEndpoint, method_str), params), response_formatters
return (method_str, params), response_formatters

def inner(
*args: Any, **kwargs: Any
) -> Tuple[Tuple[RPCEndpoint, Any], Sequence[Any]]:
def inner(*args: Any, **kwargs: Any) -> RequestAndFormatters:
(method_str, params), response_formatters = method.process_params(
module, *args, **kwargs
)
return (cast(RPCEndpoint, method_str), params), response_formatters
return (method_str, params), response_formatters

return async_inner if module.is_async else inner

Expand Down Expand Up @@ -142,7 +157,7 @@ async def caller(

if isinstance(async_w3.provider, PersistentConnectionProvider):
return await async_w3.manager.socket_request(
cast(RPCEndpoint, method_str),
method_str,
params,
response_formatters=response_formatters,
)
Expand All @@ -167,13 +182,23 @@ async def caller(
class Module:
is_async = False

retrieve_request_information: Callable[
[Method[TFunc]],
Union[
Callable[..., RequestAndFormatters],
Callable[..., Coroutine[Any, Any, RequestAndFormatters]],
],
]

def __init__(self, w3: Union["AsyncWeb3", "Web3"]) -> None:
if self.is_async:
self.retrieve_caller_fn = retrieve_async_method_call_fn(w3, self)
else:
self.retrieve_caller_fn = retrieve_blocking_method_call_fn(w3, self)
self.retrieve_request_information = retrieve_request_information_for_batching(
w3, self
self.retrieve_request_information = (
retrieve_request_information_for_batching( # type: ignore [call-overload]
w3, self
)
)
self.w3 = w3

Expand Down
11 changes: 7 additions & 4 deletions web3/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@
Timestamp = NewType("Timestamp", int)
Wei = NewType("Wei", int)
Gwei = NewType("Gwei", int)
Formatters = Dict[RPCEndpoint, Callable[..., Any]]
Formatter = Callable[..., TReturn]
Formatters = Dict[RPCEndpoint, Formatter[Any]]
ResponseFormatter = Callable[["RPCResponse"], TReturn]
ResponseFormatters = Dict[RPCEndpoint, ResponseFormatter[Any]]


class AccessListEntry(TypedDict):
Expand Down Expand Up @@ -335,9 +338,9 @@ class CreateAccessListResponse(TypedDict):


class FormattersDict(TypedDict, total=False):
error_formatters: Optional[Formatters]
request_formatters: Optional[Formatters]
result_formatters: Optional[Formatters]
error_formatters: Optional[ResponseFormatters]
request_formatters: Optional[ResponseFormatters]
result_formatters: Optional[ResponseFormatters]


class FilterParams(TypedDict, total=False):
Expand Down