Skip to content

Commit 7e77c3b

Browse files
async support (work in progress)
1 parent 7c8d668 commit 7e77c3b

File tree

8 files changed

+1284
-922
lines changed

8 files changed

+1284
-922
lines changed

elasticsearch_dsl/_async/search.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Licensed to Elasticsearch B.V. under one or more contributor
2+
# license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright
4+
# ownership. Elasticsearch B.V. licenses this file to you under
5+
# the Apache License, Version 2.0 (the "License"); you may
6+
# not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from elasticsearch.exceptions import ApiError
19+
from elasticsearch.helpers import async_scan
20+
21+
from ..async_connections import get_connection
22+
from ..base_search import BaseMultiSearch, BaseSearch
23+
from ..response import Response
24+
from ..utils import AttrDict
25+
26+
27+
class Search(BaseSearch):
28+
def __aiter__(self):
29+
"""
30+
Iterate over the hits.
31+
"""
32+
return aiter(await self.execute())
33+
34+
async def count(self):
35+
"""
36+
Return the number of hits matching the query and filters. Note that
37+
only the actual number is returned.
38+
"""
39+
if hasattr(self, "_response") and self._response.hits.total.relation == "eq":
40+
return self._response.hits.total.value
41+
42+
es = get_connection(self._using)
43+
44+
d = self.to_dict(count=True)
45+
# TODO: failed shards detection
46+
resp = await es.count(
47+
index=self._index, query=d.get("query", None), **self._params
48+
)
49+
return resp["count"]
50+
51+
async def execute(self, ignore_cache=False):
52+
"""
53+
Execute the search and return an instance of ``Response`` wrapping all
54+
the data.
55+
56+
:arg ignore_cache: if set to ``True``, consecutive calls will hit
57+
ES, while cached result will be ignored. Defaults to `False`
58+
"""
59+
if ignore_cache or not hasattr(self, "_response"):
60+
es = get_connection(self._using)
61+
62+
self._response = self._response_class(
63+
self,
64+
(
65+
await es.search(
66+
index=self._index, body=self.to_dict(), **self._params
67+
)
68+
).body,
69+
)
70+
return self._response
71+
72+
async def scan(self):
73+
"""
74+
Turn the search into a scan search and return a generator that will
75+
iterate over all the documents matching the query.
76+
77+
Use ``params`` method to specify any additional arguments you with to
78+
pass to the underlying ``scan`` helper from ``elasticsearch-py`` -
79+
https://elasticsearch-py.readthedocs.io/en/master/helpers.html#elasticsearch.helpers.scan
80+
81+
"""
82+
es = get_connection(self._using)
83+
84+
async for hit in async_scan(
85+
es, query=self.to_dict(), index=self._index, **self._params
86+
):
87+
yield self._get_result(hit)
88+
89+
async def delete(self):
90+
"""
91+
delete() executes the query by delegating to delete_by_query()
92+
"""
93+
94+
es = get_connection(self._using)
95+
96+
return AttrDict(
97+
await es.delete_by_query(
98+
index=self._index, body=self.to_dict(), **self._params
99+
)
100+
)
101+
102+
103+
class MultiSearch(BaseMultiSearch):
104+
"""
105+
Combine multiple :class:`~elasticsearch_dsl.Search` objects into a single
106+
request.
107+
"""
108+
109+
async def execute(self, ignore_cache=False, raise_on_error=True):
110+
"""
111+
Execute the multi search request and return a list of search results.
112+
"""
113+
if ignore_cache or not hasattr(self, "_response"):
114+
es = get_connection(self._using)
115+
116+
responses = await es.msearch(
117+
index=self._index, body=self.to_dict(), **self._params
118+
)
119+
120+
out = []
121+
for s, r in zip(self._searches, responses["responses"]):
122+
if r.get("error", False):
123+
if raise_on_error:
124+
raise ApiError("N/A", meta=responses.meta, body=r)
125+
r = None
126+
else:
127+
r = Response(s, r)
128+
out.append(r)
129+
130+
self._response = out
131+
132+
return self._response

elasticsearch_dsl/_sync/search.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Licensed to Elasticsearch B.V. under one or more contributor
2+
# license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright
4+
# ownership. Elasticsearch B.V. licenses this file to you under
5+
# the Apache License, Version 2.0 (the "License"); you may
6+
# not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from elasticsearch.exceptions import ApiError
19+
from elasticsearch.helpers import scan
20+
21+
from ..base_search import BaseMultiSearch, BaseSearch
22+
from ..connections import get_connection
23+
from ..response import Response
24+
from ..utils import AttrDict
25+
26+
27+
class Search(BaseSearch):
28+
def __iter__(self):
29+
"""
30+
Iterate over the hits.
31+
"""
32+
return iter(self.execute())
33+
34+
def count(self):
35+
"""
36+
Return the number of hits matching the query and filters. Note that
37+
only the actual number is returned.
38+
"""
39+
if hasattr(self, "_response") and self._response.hits.total.relation == "eq":
40+
return self._response.hits.total.value
41+
42+
es = get_connection(self._using)
43+
44+
d = self.to_dict(count=True)
45+
# TODO: failed shards detection
46+
resp = es.count(index=self._index, query=d.get("query", None), **self._params)
47+
return resp["count"]
48+
49+
def execute(self, ignore_cache=False):
50+
"""
51+
Execute the search and return an instance of ``Response`` wrapping all
52+
the data.
53+
54+
:arg ignore_cache: if set to ``True``, consecutive calls will hit
55+
ES, while cached result will be ignored. Defaults to `False`
56+
"""
57+
if ignore_cache or not hasattr(self, "_response"):
58+
es = get_connection(self._using)
59+
60+
self._response = self._response_class(
61+
self,
62+
(
63+
es.search(index=self._index, body=self.to_dict(), **self._params)
64+
).body,
65+
)
66+
return self._response
67+
68+
def scan(self):
69+
"""
70+
Turn the search into a scan search and return a generator that will
71+
iterate over all the documents matching the query.
72+
73+
Use ``params`` method to specify any additional arguments you with to
74+
pass to the underlying ``scan`` helper from ``elasticsearch-py`` -
75+
https://elasticsearch-py.readthedocs.io/en/master/helpers.html#elasticsearch.helpers.scan
76+
77+
"""
78+
es = get_connection(self._using)
79+
80+
for hit in scan(es, query=self.to_dict(), index=self._index, **self._params):
81+
yield self._get_result(hit)
82+
83+
def delete(self):
84+
"""
85+
delete() executes the query by delegating to delete_by_query()
86+
"""
87+
88+
es = get_connection(self._using)
89+
90+
return AttrDict(
91+
es.delete_by_query(index=self._index, body=self.to_dict(), **self._params)
92+
)
93+
94+
95+
class MultiSearch(BaseMultiSearch):
96+
"""
97+
Combine multiple :class:`~elasticsearch_dsl.Search` objects into a single
98+
request.
99+
"""
100+
101+
def execute(self, ignore_cache=False, raise_on_error=True):
102+
"""
103+
Execute the multi search request and return a list of search results.
104+
"""
105+
if ignore_cache or not hasattr(self, "_response"):
106+
es = get_connection(self._using)
107+
108+
responses = es.msearch(
109+
index=self._index, body=self.to_dict(), **self._params
110+
)
111+
112+
out = []
113+
for s, r in zip(self._searches, responses["responses"]):
114+
if r.get("error", False):
115+
if raise_on_error:
116+
raise ApiError("N/A", meta=responses.meta, body=r)
117+
r = None
118+
else:
119+
r = Response(s, r)
120+
out.append(r)
121+
122+
self._response = out
123+
124+
return self._response
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Licensed to Elasticsearch B.V. under one or more contributor
2+
# license agreements. See the NOTICE file distributed with
3+
# this work for additional information regarding copyright
4+
# ownership. Elasticsearch B.V. licenses this file to you under
5+
# the Apache License, Version 2.0 (the "License"); you may
6+
# not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from elasticsearch import AsyncElasticsearch
19+
20+
from elasticsearch_dsl.connections import Connections
21+
22+
connections = Connections(elasticsearch_class=AsyncElasticsearch)
23+
configure = connections.configure
24+
add_connection = connections.add_connection
25+
remove_connection = connections.remove_connection
26+
create_connection = connections.create_connection
27+
get_connection = connections.get_connection

0 commit comments

Comments
 (0)