|
18 | 18 |
|
19 | 19 | import os
|
20 | 20 | import re
|
| 21 | +import time |
21 | 22 | from datetime import datetime
|
| 23 | +from unittest import SkipTest, TestCase |
22 | 24 | from unittest.mock import Mock
|
23 | 25 |
|
| 26 | +from elasticsearch import Elasticsearch |
| 27 | +from elasticsearch.exceptions import ConnectionError |
24 | 28 | from elasticsearch.helpers import bulk
|
25 |
| -from elasticsearch.helpers.test import SkipTest, get_test_client |
26 | 29 | from pytest import fixture, skip
|
27 | 30 |
|
28 | 31 | from elasticsearch_dsl.connections import add_connection, connections
|
|
36 | 39 | )
|
37 | 40 | from .test_integration.test_document import Comment, History, PullRequest, User
|
38 | 41 |
|
| 42 | +if "ELASTICSEARCH_URL" in os.environ: |
| 43 | + ELASTICSEARCH_URL = os.environ["ELASTICSEARCH_URL"] |
| 44 | +else: |
| 45 | + ELASTICSEARCH_URL = "http://localhost:9200" |
| 46 | + |
| 47 | + |
| 48 | +def get_test_client(wait=True, **kwargs): |
| 49 | + # construct kwargs from the environment |
| 50 | + kw = {"timeout": 30} |
| 51 | + |
| 52 | + if "PYTHON_CONNECTION_CLASS" in os.environ: |
| 53 | + from elasticsearch import connection |
| 54 | + |
| 55 | + kw["connection_class"] = getattr( |
| 56 | + connection, os.environ["PYTHON_CONNECTION_CLASS"] |
| 57 | + ) |
| 58 | + |
| 59 | + kw.update(kwargs) |
| 60 | + client = Elasticsearch(ELASTICSEARCH_URL, **kw) |
| 61 | + |
| 62 | + # wait for yellow status |
| 63 | + for tries_left in range(100 if wait else 1, 0, -1): |
| 64 | + try: |
| 65 | + client.cluster.health(wait_for_status="yellow") |
| 66 | + return client |
| 67 | + except ConnectionError: |
| 68 | + if wait and tries_left == 1: |
| 69 | + raise |
| 70 | + time.sleep(0.1) |
| 71 | + |
| 72 | + raise SkipTest("Elasticsearch failed to start.") |
| 73 | + |
| 74 | + |
| 75 | +class ElasticsearchTestCase(TestCase): |
| 76 | + @staticmethod |
| 77 | + def _get_client(): |
| 78 | + return get_test_client() |
| 79 | + |
| 80 | + @classmethod |
| 81 | + def setup_class(cls): |
| 82 | + cls.client = cls._get_client() |
| 83 | + |
| 84 | + def teardown_method(self, _): |
| 85 | + # Hidden indices expanded in wildcards in ES 7.7 |
| 86 | + expand_wildcards = ["open", "closed"] |
| 87 | + if self.es_version() >= (7, 7): |
| 88 | + expand_wildcards.append("hidden") |
| 89 | + |
| 90 | + self.client.indices.delete_data_stream( |
| 91 | + name="*", ignore=404, expand_wildcards=expand_wildcards |
| 92 | + ) |
| 93 | + self.client.indices.delete( |
| 94 | + index="*", ignore=404, expand_wildcards=expand_wildcards |
| 95 | + ) |
| 96 | + self.client.indices.delete_template(name="*", ignore=404) |
| 97 | + |
| 98 | + def es_version(self): |
| 99 | + if not hasattr(self, "_es_version"): |
| 100 | + self._es_version = _get_version(client.info()["version"]["number"]) |
| 101 | + return self._es_version |
| 102 | + |
| 103 | + |
| 104 | +def _get_version(version_string): |
| 105 | + if "." not in version_string: |
| 106 | + return () |
| 107 | + version = version_string.strip().split(".") |
| 108 | + return tuple(int(v) if v.isdigit() else 999 for v in version) |
| 109 | + |
39 | 110 |
|
40 | 111 | @fixture(scope="session")
|
41 | 112 | def client():
|
42 | 113 | try:
|
43 |
| - connection = get_test_client(nowait="WAIT_FOR_ES" not in os.environ) |
| 114 | + connection = get_test_client(wait="WAIT_FOR_ES" in os.environ) |
44 | 115 | add_connection("default", connection)
|
45 | 116 | return connection
|
46 | 117 | except SkipTest:
|
|
0 commit comments