|
| 1 | +# Copyright 2017 MongoDB, Inc. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Run the SRV support tests.""" |
| 16 | +from __future__ import annotations |
| 17 | + |
| 18 | +import glob |
| 19 | +import json |
| 20 | +import os |
| 21 | +import pathlib |
| 22 | +import sys |
| 23 | + |
| 24 | +sys.path[0:0] = [""] |
| 25 | + |
| 26 | +from test.asynchronous import ( |
| 27 | + AsyncIntegrationTest, |
| 28 | + AsyncPyMongoTestCase, |
| 29 | + async_client_context, |
| 30 | + unittest, |
| 31 | +) |
| 32 | +from test.utils import async_wait_until |
| 33 | + |
| 34 | +from pymongo.common import validate_read_preference_tags |
| 35 | +from pymongo.errors import ConfigurationError |
| 36 | +from pymongo.uri_parser import parse_uri, split_hosts |
| 37 | + |
| 38 | +_IS_SYNC = False |
| 39 | + |
| 40 | + |
| 41 | +class TestDNSRepl(AsyncPyMongoTestCase): |
| 42 | + if _IS_SYNC: |
| 43 | + TEST_PATH = os.path.join( |
| 44 | + pathlib.Path(__file__).resolve().parent, "srv_seedlist", "replica-set" |
| 45 | + ) |
| 46 | + else: |
| 47 | + TEST_PATH = os.path.join( |
| 48 | + pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "replica-set" |
| 49 | + ) |
| 50 | + load_balanced = False |
| 51 | + |
| 52 | + @async_client_context.require_replica_set |
| 53 | + def asyncSetUp(self): |
| 54 | + pass |
| 55 | + |
| 56 | + |
| 57 | +class TestDNSLoadBalanced(AsyncPyMongoTestCase): |
| 58 | + if _IS_SYNC: |
| 59 | + TEST_PATH = os.path.join( |
| 60 | + pathlib.Path(__file__).resolve().parent, "srv_seedlist", "load-balanced" |
| 61 | + ) |
| 62 | + else: |
| 63 | + TEST_PATH = os.path.join( |
| 64 | + pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "load-balanced" |
| 65 | + ) |
| 66 | + load_balanced = True |
| 67 | + |
| 68 | + @async_client_context.require_load_balancer |
| 69 | + def asyncSetUp(self): |
| 70 | + pass |
| 71 | + |
| 72 | + |
| 73 | +class TestDNSSharded(AsyncPyMongoTestCase): |
| 74 | + if _IS_SYNC: |
| 75 | + TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "srv_seedlist", "sharded") |
| 76 | + else: |
| 77 | + TEST_PATH = os.path.join( |
| 78 | + pathlib.Path(__file__).resolve().parent.parent, "srv_seedlist", "sharded" |
| 79 | + ) |
| 80 | + load_balanced = False |
| 81 | + |
| 82 | + @async_client_context.require_mongos |
| 83 | + def asyncSetUp(self): |
| 84 | + pass |
| 85 | + |
| 86 | + |
| 87 | +def create_test(test_case): |
| 88 | + async def run_test(self): |
| 89 | + uri = test_case["uri"] |
| 90 | + seeds = test_case.get("seeds") |
| 91 | + num_seeds = test_case.get("numSeeds", len(seeds or [])) |
| 92 | + hosts = test_case.get("hosts") |
| 93 | + num_hosts = test_case.get("numHosts", len(hosts or [])) |
| 94 | + |
| 95 | + options = test_case.get("options", {}) |
| 96 | + if "ssl" in options: |
| 97 | + options["tls"] = options.pop("ssl") |
| 98 | + parsed_options = test_case.get("parsed_options") |
| 99 | + # See DRIVERS-1324, unless tls is explicitly set to False we need TLS. |
| 100 | + needs_tls = not (options and (options.get("ssl") is False or options.get("tls") is False)) |
| 101 | + if needs_tls and not async_client_context.tls: |
| 102 | + self.skipTest("this test requires a TLS cluster") |
| 103 | + if not needs_tls and async_client_context.tls: |
| 104 | + self.skipTest("this test requires a non-TLS cluster") |
| 105 | + |
| 106 | + if seeds: |
| 107 | + seeds = split_hosts(",".join(seeds)) |
| 108 | + if hosts: |
| 109 | + hosts = frozenset(split_hosts(",".join(hosts))) |
| 110 | + |
| 111 | + if seeds or num_seeds: |
| 112 | + result = parse_uri(uri, validate=True) |
| 113 | + if seeds is not None: |
| 114 | + self.assertEqual(sorted(result["nodelist"]), sorted(seeds)) |
| 115 | + if num_seeds is not None: |
| 116 | + self.assertEqual(len(result["nodelist"]), num_seeds) |
| 117 | + if options: |
| 118 | + opts = result["options"] |
| 119 | + if "readpreferencetags" in opts: |
| 120 | + rpts = validate_read_preference_tags( |
| 121 | + "readPreferenceTags", opts.pop("readpreferencetags") |
| 122 | + ) |
| 123 | + opts["readPreferenceTags"] = rpts |
| 124 | + self.assertEqual(result["options"], options) |
| 125 | + if parsed_options: |
| 126 | + for opt, expected in parsed_options.items(): |
| 127 | + if opt == "user": |
| 128 | + self.assertEqual(result["username"], expected) |
| 129 | + elif opt == "password": |
| 130 | + self.assertEqual(result["password"], expected) |
| 131 | + elif opt == "auth_database" or opt == "db": |
| 132 | + self.assertEqual(result["database"], expected) |
| 133 | + |
| 134 | + hostname = next(iter(async_client_context.client.nodes))[0] |
| 135 | + # The replica set members must be configured as 'localhost'. |
| 136 | + if hostname == "localhost": |
| 137 | + copts = async_client_context.default_client_options.copy() |
| 138 | + # Remove tls since SRV parsing should add it automatically. |
| 139 | + copts.pop("tls", None) |
| 140 | + if async_client_context.tls: |
| 141 | + # Our test certs don't support the SRV hosts used in these |
| 142 | + # tests. |
| 143 | + copts["tlsAllowInvalidHostnames"] = True |
| 144 | + |
| 145 | + client = self.simple_client(uri, **copts) |
| 146 | + if client._options.connect: |
| 147 | + await client.aconnect() |
| 148 | + if num_seeds is not None: |
| 149 | + self.assertEqual(len(client._topology_settings.seeds), num_seeds) |
| 150 | + if hosts is not None: |
| 151 | + await async_wait_until( |
| 152 | + lambda: hosts == client.nodes, "match test hosts to client nodes" |
| 153 | + ) |
| 154 | + if num_hosts is not None: |
| 155 | + await async_wait_until( |
| 156 | + lambda: num_hosts == len(client.nodes), "wait to connect to num_hosts" |
| 157 | + ) |
| 158 | + if test_case.get("ping", True): |
| 159 | + await client.admin.command("ping") |
| 160 | + # XXX: we should block until SRV poller runs at least once |
| 161 | + # and re-run these assertions. |
| 162 | + else: |
| 163 | + try: |
| 164 | + parse_uri(uri) |
| 165 | + except (ConfigurationError, ValueError): |
| 166 | + pass |
| 167 | + else: |
| 168 | + self.fail("failed to raise an exception") |
| 169 | + |
| 170 | + return run_test |
| 171 | + |
| 172 | + |
| 173 | +def create_tests(cls): |
| 174 | + for filename in glob.glob(os.path.join(cls.TEST_PATH, "*.json")): |
| 175 | + test_suffix, _ = os.path.splitext(os.path.basename(filename)) |
| 176 | + with open(filename) as dns_test_file: |
| 177 | + test_method = create_test(json.load(dns_test_file)) |
| 178 | + setattr(cls, "test_" + test_suffix, test_method) |
| 179 | + |
| 180 | + |
| 181 | +create_tests(TestDNSRepl) |
| 182 | +create_tests(TestDNSLoadBalanced) |
| 183 | +create_tests(TestDNSSharded) |
| 184 | + |
| 185 | + |
| 186 | +class TestParsingErrors(AsyncPyMongoTestCase): |
| 187 | + async def test_invalid_host(self): |
| 188 | + self.assertRaisesRegex( |
| 189 | + ConfigurationError, |
| 190 | + "Invalid URI host: mongodb is not", |
| 191 | + self.simple_client, |
| 192 | + "mongodb+srv://mongodb", |
| 193 | + ) |
| 194 | + self.assertRaisesRegex( |
| 195 | + ConfigurationError, |
| 196 | + "Invalid URI host: mongodb.com is not", |
| 197 | + self.simple_client, |
| 198 | + "mongodb+srv://mongodb.com", |
| 199 | + ) |
| 200 | + self.assertRaisesRegex( |
| 201 | + ConfigurationError, |
| 202 | + "Invalid URI host: an IP address is not", |
| 203 | + self.simple_client, |
| 204 | + "mongodb+srv://127.0.0.1", |
| 205 | + ) |
| 206 | + self.assertRaisesRegex( |
| 207 | + ConfigurationError, |
| 208 | + "Invalid URI host: an IP address is not", |
| 209 | + self.simple_client, |
| 210 | + "mongodb+srv://[::1]", |
| 211 | + ) |
| 212 | + |
| 213 | + |
| 214 | +class IsolatedAsyncioTestCaseInsensitive(AsyncIntegrationTest): |
| 215 | + async def test_connect_case_insensitive(self): |
| 216 | + client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/") |
| 217 | + self.assertGreater(len(client.topology_description.server_descriptions()), 1) |
| 218 | + |
| 219 | + |
| 220 | +if __name__ == "__main__": |
| 221 | + unittest.main() |
0 commit comments