Commit 8add69f

async support for analyzer
1 parent 0184a77 commit 8add69f

4 files changed: +90 -19 lines changed

elasticsearch_dsl/analysis.py

Lines changed: 39 additions & 16 deletions
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from .connections import get_connection
+from . import async_connections, connections
 from .utils import AttrDict, DslBase, merge
 
 __all__ = ["tokenizer", "analyzer", "char_filter", "token_filter", "normalizer"]
@@ -119,20 +119,7 @@ class CustomAnalyzer(CustomAnalysisDefinition, Analyzer):
         "tokenizer": {"type": "tokenizer"},
     }
 
-    def simulate(self, text, using="default", explain=False, attributes=None):
-        """
-        Use the Analyze API of elasticsearch to test the outcome of this analyzer.
-
-        :arg text: Text to be analyzed
-        :arg using: connection alias to use, defaults to ``'default'``
-        :arg explain: will output all token attributes for each token. You can
-            filter token attributes you want to output by setting ``attributes``
-            option.
-        :arg attributes: if ``explain`` is specified, filter the token
-            attributes to return.
-        """
-        es = get_connection(using)
-
+    def _get_body(self, text, explain, attributes):
         body = {"text": text, "explain": explain}
         if attributes:
             body["attributes"] = attributes
@@ -156,7 +143,43 @@ def simulate(self, text, using="default", explain=False, attributes=None):
         if self._builtin_type != "custom":
             body["analyzer"] = self._builtin_type
 
-        return AttrDict(es.indices.analyze(body=body))
+        return body
+
+    def simulate(self, text, using="default", explain=False, attributes=None):
+        """
+        Use the Analyze API of elasticsearch to test the outcome of this analyzer.
+
+        :arg text: Text to be analyzed
+        :arg using: connection alias to use, defaults to ``'default'``
+        :arg explain: will output all token attributes for each token. You can
+            filter token attributes you want to output by setting ``attributes``
+            option.
+        :arg attributes: if ``explain`` is specified, filter the token
+            attributes to return.
+        """
+        es = connections.get_connection(using)
+        return AttrDict(
+            es.indices.analyze(body=self._get_body(text, explain, attributes))
+        )
+
+    async def async_simulate(
+        self, text, using="default", explain=False, attributes=None
+    ):
+        """
+        Use the Analyze API of elasticsearch to test the outcome of this analyzer.
+
+        :arg text: Text to be analyzed
+        :arg using: connection alias to use, defaults to ``'default'``
+        :arg explain: will output all token attributes for each token. You can
+            filter token attributes you want to output by setting ``attributes``
+            option.
+        :arg attributes: if ``explain`` is specified, filter the token
+            attributes to return.
+        """
+        es = async_connections.get_connection(using)
+        return AttrDict(
+            await es.indices.analyze(body=self._get_body(text, explain, attributes))
+        )
 
 
 class Normalizer(AnalysisBase, DslBase):
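
The net effect: request-body construction is shared via _get_body(), while simulate() resolves a blocking connection through connections and async_simulate() awaits the same Analyze call through async_connections. A minimal usage sketch of the new coroutine, assuming the AsyncElasticsearch client from the elasticsearch package and a node at localhost:9200 (neither is part of this commit); passing the client instance itself as `using` mirrors the integration tests below:

# Hypothetical usage sketch, not part of this commit. Assumes the
# `elasticsearch` package provides AsyncElasticsearch and that an
# Elasticsearch node is reachable at http://localhost:9200.
import asyncio

from elasticsearch import AsyncElasticsearch
from elasticsearch_dsl import analyzer, token_filter, tokenizer


async def main():
    client = AsyncElasticsearch(hosts=["http://localhost:9200"])
    a = analyzer(
        "my-analyzer",
        tokenizer=tokenizer("split_words", "simple_pattern_split", pattern=":"),
        filter=["lowercase", token_filter("no-ifs", "stop", stopwords=["if"])],
    )
    # async_simulate awaits the Analyze API call instead of blocking on it;
    # a client instance passed as `using` is used directly as the connection.
    result = await a.async_simulate("if:this:works", using=client)
    print([t.token for t in result.tokens])  # expected: ['this', 'works']
    await client.close()


asyncio.run(main())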
New file: async integration tests for the analyzer

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+# Licensed to Elasticsearch B.V. under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Elasticsearch B.V. licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from elasticsearch_dsl import analyzer, token_filter, tokenizer
+
+
+async def test_simulate_with_just__builtin_tokenizer(async_client):
+    a = analyzer("my-analyzer", tokenizer="keyword")
+    tokens = (await a.async_simulate("Hello World!", using=async_client)).tokens
+
+    assert len(tokens) == 1
+    assert tokens[0].token == "Hello World!"
+
+
+async def test_simulate_complex(async_client):
+    a = analyzer(
+        "my-analyzer",
+        tokenizer=tokenizer("split_words", "simple_pattern_split", pattern=":"),
+        filter=["lowercase", token_filter("no-ifs", "stop", stopwords=["if"])],
+    )
+
+    tokens = (await a.async_simulate("if:this:works", using=async_client)).tokens
+
+    assert len(tokens) == 2
+    assert ["this", "works"] == [t.token for t in tokens]
+
+
+async def test_simulate_builtin(async_client):
+    a = analyzer("my-analyzer", "english")
+    tokens = (await a.async_simulate("fixes running")).tokens
+
+    assert ["fix", "run"] == [t.token for t in tokens]

tests/test_integration/test_analysis.py renamed to tests/test_integration/_sync/test_analysis.py

Lines changed: 3 additions & 3 deletions
@@ -20,7 +20,7 @@
 
 def test_simulate_with_just__builtin_tokenizer(client):
     a = analyzer("my-analyzer", tokenizer="keyword")
-    tokens = a.simulate("Hello World!", using=client).tokens
+    tokens = (a.simulate("Hello World!", using=client)).tokens
 
     assert len(tokens) == 1
     assert tokens[0].token == "Hello World!"
@@ -33,14 +33,14 @@ def test_simulate_complex(client):
         filter=["lowercase", token_filter("no-ifs", "stop", stopwords=["if"])],
     )
 
-    tokens = a.simulate("if:this:works", using=client).tokens
+    tokens = (a.simulate("if:this:works", using=client)).tokens
 
     assert len(tokens) == 2
     assert ["this", "works"] == [t.token for t in tokens]
 
 
 def test_simulate_builtin(client):
     a = analyzer("my-analyzer", "english")
-    tokens = a.simulate("fixes running").tokens
+    tokens = (a.simulate("fixes running")).tokens
 
     assert ["fix", "run"] == [t.token for t in tokens]

utils/run-unasync.py

Lines changed: 2 additions & 0 deletions
@@ -43,7 +43,9 @@ def main(check=False):
         "AsyncFacetedSearch": "FacetedSearch",
         "async_connections": "connections",
         "async_scan": "scan",
+        "async_simulate": "simulate",
         "async_mock_client": "mock_client",
+        "async_client": "client",
         "async_data_client": "data_client",
         "async_write_client": "write_client",
         "async_pull_request": "pull_request",
