
Commit 2467efc

Vendor DataLoader from aiodataloader, and move the get_event_loop() call from __init__ into a property that is resolved only when actually needed (this fixes pytest-related early get_event_loop() issues).
1 parent 20219fd commit 2467efc
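
The pytest issue the message refers to: instantiating a DataLoader used to call get_event_loop() inside __init__, so loaders created at import time (e.g. during pytest collection) grabbed a loop before the test loop existed. A minimal sketch of the post-commit behavior, with a hypothetical double_all batch function that is not part of this change:

from graphene.utils.dataloader import DataLoader

async def double_all(keys):
    # Resolve every key collected in the current tick in one round trip.
    return [key * 2 for key in keys]

# Safe at import time now: __init__ no longer touches the event loop.
loader = DataLoader(double_all)
# get_event_loop() runs lazily, on the first access of loader.loop
# (i.e. inside the first load() or prime() call).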

File tree

3 files changed: +284 −2 lines changed


graphene/utils/dataloader.py

Lines changed: 282 additions & 0 deletions
@@ -0,0 +1,282 @@
from asyncio import (
    gather,
    ensure_future,
    get_event_loop,
    iscoroutine,
    iscoroutinefunction,
)
from collections import namedtuple
from collections.abc import Iterable
from functools import partial

from typing import List  # flake8: noqa

Loader = namedtuple("Loader", "key,future")


def iscoroutinefunctionorpartial(fn):
    return iscoroutinefunction(fn.func if isinstance(fn, partial) else fn)


class DataLoader(object):
    batch = True
    max_batch_size = None  # type: int
    cache = True

    def __init__(
        self,
        batch_load_fn=None,
        batch=None,
        max_batch_size=None,
        cache=None,
        get_cache_key=None,
        cache_map=None,
        loop=None,
    ):
        # Store any explicitly provided loop; otherwise leave _loop empty so
        # asyncio's event loop is resolved lazily, the first time it is needed.
        self._loop = loop

        if batch_load_fn is not None:
            self.batch_load_fn = batch_load_fn

        assert iscoroutinefunctionorpartial(
            self.batch_load_fn
        ), "batch_load_fn must be a coroutine function. Received: {}".format(
            self.batch_load_fn
        )

        if not callable(self.batch_load_fn):
            raise TypeError(
                (
                    "DataLoader must have a batch_load_fn which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but got: {}."
                ).format(batch_load_fn)
            )

        if batch is not None:
            self.batch = batch

        if max_batch_size is not None:
            self.max_batch_size = max_batch_size

        if cache is not None:
            self.cache = cache

        self.get_cache_key = get_cache_key or (lambda x: x)

        self._cache = cache_map if cache_map is not None else {}
        self._queue = []  # type: List[Loader]

    @property
    def loop(self):
        if not self._loop:
            self._loop = get_event_loop()

        return self._loop

    def load(self, key=None):
        """
        Loads a key, returning a `Future` for the value represented by that key.
        """
        if key is None:
            raise TypeError(
                (
                    "The loader.load() function must be called with a value, "
                    "but got: {}."
                ).format(key)
            )

        cache_key = self.get_cache_key(key)

        # If caching and there is a cache-hit, return the cached Future.
        if self.cache:
            cached_result = self._cache.get(cache_key)
            if cached_result:
                return cached_result

        # Otherwise, produce a new Future for this value.
        future = self.loop.create_future()
        # If caching, cache this Future.
        if self.cache:
            self._cache[cache_key] = future

        self.do_resolve_reject(key, future)
        return future

    def do_resolve_reject(self, key, future):
        # Enqueue this Future to be dispatched.
        self._queue.append(Loader(key=key, future=future))
        # Determine if a dispatch of this queue should be scheduled.
        # A single dispatch should be scheduled per queue at the time when the
        # queue changes from "empty" to "full".
        if len(self._queue) == 1:
            if self.batch:
                # If batching, schedule a task to dispatch the queue.
                enqueue_post_future_job(self.loop, self)
            else:
                # Otherwise dispatch the (queue of one) immediately.
                dispatch_queue(self)

    def load_many(self, keys):
        """
        Loads multiple keys, returning a list of values.

        >>> a, b = await my_loader.load_many(['a', 'b'])

        This is equivalent to the more verbose:

        >>> a, b = await gather(
        >>>     my_loader.load('a'),
        >>>     my_loader.load('b')
        >>> )
        """
        if not isinstance(keys, Iterable):
            raise TypeError(
                (
                    "The loader.load_many() function must be called with Iterable<key> "
                    "but got: {}."
                ).format(keys)
            )

        return gather(*[self.load(key) for key in keys])

    def clear(self, key):
        """
        Clears the value at `key` from the cache, if it exists. Returns itself for
        method chaining.
        """
        cache_key = self.get_cache_key(key)
        self._cache.pop(cache_key, None)
        return self

    def clear_all(self):
        """
        Clears the entire cache. To be used when some event results in unknown
        invalidations across this particular `DataLoader`. Returns itself for
        method chaining.
        """
        self._cache.clear()
        return self

    def prime(self, key, value):
        """
        Adds the provided key and value to the cache. If the key already exists, no
        change is made. Returns itself for method chaining.
        """
        cache_key = self.get_cache_key(key)

        # Only add the key if it does not already exist.
        if cache_key not in self._cache:
            # Cache a rejected future if the value is an Error, in order to match
            # the behavior of load(key).
            future = self.loop.create_future()
            if isinstance(value, Exception):
                future.set_exception(value)
            else:
                future.set_result(value)

            self._cache[cache_key] = future

        return self


def enqueue_post_future_job(loop, loader):
    async def dispatch():
        dispatch_queue(loader)

    loop.call_soon(ensure_future, dispatch())


def get_chunks(iterable_obj, chunk_size=1):
    chunk_size = max(1, chunk_size)
    return (
        iterable_obj[i : i + chunk_size]
        for i in range(0, len(iterable_obj), chunk_size)
    )


def dispatch_queue(loader):
    """
    Given the current state of a Loader instance, perform a batch load
    from its current queue.
    """
    # Take the current loader queue, replacing it with an empty queue.
    queue = loader._queue
    loader._queue = []

    # If a max_batch_size was provided and the queue is longer, then segment the
    # queue into multiple batches, otherwise treat the queue as a single batch.
    max_batch_size = loader.max_batch_size

    if max_batch_size and max_batch_size < len(queue):
        chunks = get_chunks(queue, max_batch_size)
        for chunk in chunks:
            ensure_future(dispatch_queue_batch(loader, chunk))
    else:
        ensure_future(dispatch_queue_batch(loader, queue))


async def dispatch_queue_batch(loader, queue):
    # Collect all keys to be loaded in this dispatch.
    keys = [loaded.key for loaded in queue]

    # Call the provided batch_load_fn for this loader with the loader queue's keys.
    batch_future = loader.batch_load_fn(keys)

    # Assert the expected response from batch_load_fn.
    if not batch_future or not iscoroutine(batch_future):
        return failed_dispatch(
            loader,
            queue,
            TypeError(
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but the function did "
                    "not return a Coroutine: {}."
                ).format(batch_future)
            ),
        )

    try:
        values = await batch_future
        if not isinstance(values, Iterable):
            raise TypeError(
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but the function did "
                    "not return a Future of an Iterable: {}."
                ).format(values)
            )

        values = list(values)
        if len(values) != len(keys):
            raise TypeError(
                (
                    "DataLoader must be constructed with a function which accepts "
                    "Iterable<key> and returns Future<Iterable<value>>, but the function did "
                    "not return a Future of an Iterable with the same length as the Iterable "
                    "of keys."
                    "\n\nKeys:\n{}"
                    "\n\nValues:\n{}"
                ).format(keys, values)
            )

        # Step through the values, resolving or rejecting each Future in the
        # loaded queue.
        for loaded, value in zip(queue, values):
            if isinstance(value, Exception):
                loaded.future.set_exception(value)
            else:
                loaded.future.set_result(value)

    except Exception as e:
        return failed_dispatch(loader, queue, e)


def failed_dispatch(loader, queue, error):
    """
    Do not cache individual loads if the entire batch dispatch fails,
    but still reject each request so they do not hang.
    """
    for loaded in queue:
        loader.clear(loaded.key)
        loaded.future.set_exception(error)
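
For orientation, a minimal usage sketch of the vendored class; the DoubleLoader name and the asserted values are illustrative, not part of the commit:

from asyncio import run
from graphene.utils.dataloader import DataLoader

class DoubleLoader(DataLoader):
    async def batch_load_fn(self, keys):
        # Invoked once with all keys requested in the same tick.
        return [key * 2 for key in keys]

async def main():
    loader = DoubleLoader()
    # Both loads are coalesced into a single batch_load_fn call.
    a, b = await loader.load_many([1, 2])
    assert (a, b) == (2, 4)
    # A repeated key is served from the per-instance Future cache.
    assert await loader.load(1) == 2

run(main())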

setup.py

Lines changed: 0 additions & 1 deletion
@@ -53,7 +53,6 @@ def run_tests(self):
     "snapshottest>=0.6,<1",
     "coveralls>=3.3,<4",
     "promise>=2.3,<3",
-    "aiodataloader<1",
     "mock>=4,<5",
     "pytz==2022.1",
     "iso8601>=1,<2",

tests_asyncio/test_dataloader.py

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,8 @@
 from collections import namedtuple
 from unittest.mock import Mock
+
+from graphene.utils.dataloader import DataLoader
 from pytest import mark
-from aiodataloader import DataLoader
 
 from graphene import ObjectType, String, Schema, Field, List
 
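
The one-line import swap is enough because the vendored class keeps aiodataloader's API. A hedged sketch of the pattern the lazy loop unblocks in this suite (the test itself is hypothetical and assumes pytest-asyncio provides mark.asyncio):

from pytest import mark
from graphene.utils.dataloader import DataLoader

async def identity_batch(keys):
    return keys

# Module-level construction used to trigger an early get_event_loop()
# during collection; with the lazy loop property it is now harmless.
loader = DataLoader(identity_batch)

@mark.asyncio
async def test_load_roundtrip():
    assert await loader.load("a") == "a"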
