diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0eade8354..5cd05d38f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -92,3 +92,5 @@ jobs: - name: Run Tests run: | nox -rs test-${{ matrix.python-version }} + env: + WAIT_FOR_ES: "1" \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 09c092890..8400ed013 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,11 +18,14 @@ import os import re +import time from datetime import datetime +from unittest import SkipTest, TestCase from unittest.mock import Mock +from elasticsearch import Elasticsearch +from elasticsearch.exceptions import ConnectionError from elasticsearch.helpers import bulk -from elasticsearch.helpers.test import SkipTest, get_test_client from pytest import fixture, skip from elasticsearch_dsl.connections import add_connection, connections @@ -36,11 +39,79 @@ ) from .test_integration.test_document import Comment, History, PullRequest, User +if "ELASTICSEARCH_URL" in os.environ: + ELASTICSEARCH_URL = os.environ["ELASTICSEARCH_URL"] +else: + ELASTICSEARCH_URL = "http://localhost:9200" + + +def get_test_client(wait=True, **kwargs): + # construct kwargs from the environment + kw = {"timeout": 30} + + if "PYTHON_CONNECTION_CLASS" in os.environ: + from elasticsearch import connection + + kw["connection_class"] = getattr( + connection, os.environ["PYTHON_CONNECTION_CLASS"] + ) + + kw.update(kwargs) + client = Elasticsearch(ELASTICSEARCH_URL, **kw) + + # wait for yellow status + for tries_left in range(100 if wait else 1, 0, -1): + try: + client.cluster.health(wait_for_status="yellow") + return client + except ConnectionError: + if wait and tries_left == 1: + raise + time.sleep(0.1) + + raise SkipTest("Elasticsearch failed to start.") + + +class ElasticsearchTestCase(TestCase): + @staticmethod + def _get_client(): + return get_test_client() + + @classmethod + def setup_class(cls): + cls.client = cls._get_client() + + def teardown_method(self, _): + # Hidden indices expanded in wildcards in ES 7.7 + expand_wildcards = ["open", "closed"] + if self.es_version() >= (7, 7): + expand_wildcards.append("hidden") + + self.client.indices.delete_data_stream( + name="*", ignore=404, expand_wildcards=expand_wildcards + ) + self.client.indices.delete( + index="*", ignore=404, expand_wildcards=expand_wildcards + ) + self.client.indices.delete_template(name="*", ignore=404) + + def es_version(self): + if not hasattr(self, "_es_version"): + self._es_version = _get_version(client.info()["version"]["number"]) + return self._es_version + + +def _get_version(version_string): + if "." not in version_string: + return () + version = version_string.strip().split(".") + return tuple(int(v) if v.isdigit() else 999 for v in version) + @fixture(scope="session") def client(): try: - connection = get_test_client(nowait="WAIT_FOR_ES" not in os.environ) + connection = get_test_client(wait="WAIT_FOR_ES" in os.environ) add_connection("default", connection) return connection except SkipTest: