Skip to content

Commit e0b8b36

Browse files
authored
PYTHON-3813 add types to pool.py (#1318)
1 parent 5484075 commit e0b8b36

15 files changed

+313
-204
lines changed

pymongo/auth.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,15 @@ def _authenticate_scram(credentials: MongoCredential, conn: Connection, mechanis
239239

240240
ctx = conn.auth_ctx
241241
if ctx and ctx.speculate_succeeded():
242+
assert isinstance(ctx, _ScramContext)
243+
assert ctx.scram_data is not None
242244
nonce, first_bare = ctx.scram_data
243245
res = ctx.speculative_authenticate
244246
else:
245247
nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism)
246248
res = conn.command(source, cmd)
247249

250+
assert res is not None
248251
server_first = res["payload"]
249252
parsed = _parse_scram_response(server_first)
250253
iterations = int(parsed[b"i"])
@@ -575,7 +578,7 @@ def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
575578

576579

577580
class _X509Context(_AuthContext):
578-
def speculate_command(self) -> Optional[MutableMapping[str, Any]]:
581+
def speculate_command(self) -> MutableMapping[str, Any]:
579582
cmd = SON([("authenticate", 1), ("mechanism", "MONGODB-X509")])
580583
if self.credentials.username is not None:
581584
cmd["user"] = self.credentials.username

pymongo/auth_oidc.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,17 @@
1919
import threading
2020
from dataclasses import dataclass, field
2121
from datetime import datetime, timedelta, timezone
22-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple
22+
from typing import (
23+
TYPE_CHECKING,
24+
Any,
25+
Callable,
26+
Dict,
27+
List,
28+
Mapping,
29+
MutableMapping,
30+
Optional,
31+
Tuple,
32+
)
2333

2434
import bson
2535
from bson.binary import Binary
@@ -242,7 +252,9 @@ def clear(self) -> None:
242252
self.idp_resp = None
243253
self.token_exp_utc = None
244254

245-
def run_command(self, conn: Connection, cmd: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
255+
def run_command(
256+
self, conn: Connection, cmd: MutableMapping[str, Any]
257+
) -> Optional[Mapping[str, Any]]:
246258
try:
247259
return conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
248260
except OperationFailure as exc:
@@ -276,6 +288,7 @@ def authenticate(
276288
assert cmd is not None
277289
resp = self.run_command(conn, cmd)
278290

291+
assert resp is not None
279292
if resp["done"]:
280293
conn.oidc_token_gen_id = self.token_gen_id
281294
return None
@@ -297,6 +310,7 @@ def authenticate(
297310
]
298311
)
299312
resp = self.run_command(conn, cmd)
313+
assert resp is not None
300314
if not resp["done"]:
301315
self.clear()
302316
raise OperationFailure("SASL conversation failed to complete.")

pymongo/client_options.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Tools to parse mongo client options."""
1616
from __future__ import annotations
1717

18-
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Tuple, cast
18+
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence, Tuple, cast
1919

2020
from bson.codec_options import _parse_codec_options
2121
from pymongo import common
@@ -309,11 +309,12 @@ def load_balanced(self) -> Optional[bool]:
309309
return self.__load_balanced
310310

311311
@property
312-
def event_listeners(self) -> _EventListeners:
312+
def event_listeners(self) -> List[_EventListeners]:
313313
"""The event listeners registered for this client.
314314
315315
See :mod:`~pymongo.monitoring` for details.
316316
317317
.. versionadded:: 4.0
318318
"""
319+
assert self.__pool_options._event_listeners is not None
319320
return self.__pool_options._event_listeners.event_listeners()

pymongo/collection.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
TYPE_CHECKING,
2121
Any,
2222
Callable,
23-
Container,
2423
ContextManager,
2524
Generic,
2625
Iterable,
@@ -268,11 +267,11 @@ def _conn_for_writes(self, session: Optional[ClientSession]) -> ContextManager[C
268267
def _command(
269268
self,
270269
conn: Connection,
271-
command: Mapping[str, Any],
270+
command: MutableMapping[str, Any],
272271
read_preference: Optional[_ServerMode] = None,
273272
codec_options: Optional[CodecOptions] = None,
274273
check: bool = True,
275-
allowable_errors: Optional[Container[Any]] = None,
274+
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
276275
read_concern: Optional[ReadConcern] = None,
277276
write_concern: Optional[WriteConcern] = None,
278277
collation: Optional[_CollationIn] = None,
@@ -1753,7 +1752,7 @@ def _count_cmd(
17531752
session: Optional[ClientSession],
17541753
conn: Connection,
17551754
read_preference: Optional[_ServerMode],
1756-
cmd: Mapping[str, Any],
1755+
cmd: SON[str, Any],
17571756
collation: Optional[Collation],
17581757
) -> int:
17591758
"""Internal count command helper."""
@@ -1777,7 +1776,7 @@ def _aggregate_one_result(
17771776
self,
17781777
conn: Connection,
17791778
read_preference: Optional[_ServerMode],
1780-
cmd: Mapping[str, Any],
1779+
cmd: SON[str, Any],
17811780
collation: Optional[_CollationIn],
17821781
session: Optional[ClientSession],
17831782
) -> Optional[Mapping[str, Any]]:

pymongo/command_cursor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Mapping,
2626
NoReturn,
2727
Optional,
28+
Sequence,
2829
Union,
2930
)
3031

@@ -220,7 +221,7 @@ def _unpack_response(
220221
codec_options: CodecOptions[Mapping[str, Any]],
221222
user_fields: Optional[Mapping[str, Any]] = None,
222223
legacy_response: bool = False,
223-
) -> List[_DocumentOut]:
224+
) -> Sequence[_DocumentOut]:
224225
return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
225226

226227
def _refresh(self) -> int:

pymongo/compression_support.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
import warnings
17-
from typing import Any, Iterable, List, Union
17+
from typing import Any, Iterable, List, Optional, Union
1818

1919
try:
2020
import snappy
@@ -96,7 +96,7 @@ def __init__(self, compressors: List[str], zlib_compression_level: int):
9696
self.zlib_compression_level = zlib_compression_level
9797

9898
def get_compression_context(
99-
self, compressors: List[str]
99+
self, compressors: Optional[List[str]]
100100
) -> Union[SnappyContext, ZlibContext, ZstdContext, None]:
101101
if compressors:
102102
chosen = compressors[0]

pymongo/cursor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1130,7 +1130,7 @@ def _unpack_response(
11301130
codec_options: CodecOptions,
11311131
user_fields: Optional[Mapping[str, Any]] = None,
11321132
legacy_response: bool = False,
1133-
) -> List[_DocumentOut]:
1133+
) -> Sequence[_DocumentOut]:
11341134
return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
11351135

11361136
def _read_preference(self) -> _ServerMode:

pymongo/database.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,7 @@ def _command(
694694
value: int = 1,
695695
check: bool = True,
696696
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
697-
read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY,
697+
read_preference: _ServerMode = ReadPreference.PRIMARY,
698698
codec_options: CodecOptions[Dict[str, Any]] = DEFAULT_CODEC_OPTIONS,
699699
write_concern: Optional[WriteConcern] = None,
700700
parse_write_concern_error: bool = False,
@@ -711,7 +711,7 @@ def _command(
711711
value: int = 1,
712712
check: bool = True,
713713
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
714-
read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY,
714+
read_preference: _ServerMode = ReadPreference.PRIMARY,
715715
codec_options: CodecOptions[_CodecDocumentType] = ...,
716716
write_concern: Optional[WriteConcern] = None,
717717
parse_write_concern_error: bool = False,
@@ -727,7 +727,7 @@ def _command(
727727
value: int = 1,
728728
check: bool = True,
729729
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
730-
read_preference: Optional[_ServerMode] = ReadPreference.PRIMARY,
730+
read_preference: _ServerMode = ReadPreference.PRIMARY,
731731
codec_options: Union[
732732
CodecOptions[Dict[str, Any]], CodecOptions[_CodecDocumentType]
733733
] = DEFAULT_CODEC_OPTIONS,

pymongo/message.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,7 +1018,7 @@ def unack_write(
10181018
cmd = self._start(cmd, request_id, docs)
10191019
start = datetime.datetime.now()
10201020
try:
1021-
result = self.conn.unack_write(msg, max_doc_size)
1021+
result = self.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value]
10221022
if self.publish:
10231023
duration = (datetime.datetime.now() - start) + duration
10241024
if result is not None:
@@ -1050,7 +1050,7 @@ def write_command(
10501050
request_id: int,
10511051
msg: bytes,
10521052
docs: List[Mapping[str, Any]],
1053-
) -> Mapping[str, Any]:
1053+
) -> Dict[str, Any]:
10541054
"""A proxy for SocketInfo.write_command that handles event publishing."""
10551055
if self.publish:
10561056
assert self.start_time is not None
@@ -1127,7 +1127,7 @@ class _EncryptedBulkWriteContext(_BulkWriteContext):
11271127

11281128
def __batch_command(
11291129
self, cmd: MutableMapping[str, Any], docs: List[Mapping[str, Any]]
1130-
) -> Tuple[Mapping[str, Any], List[Mapping[str, Any]]]:
1130+
) -> Tuple[Dict[str, Any], List[Mapping[str, Any]]]:
11311131
namespace = self.db_name + ".$cmd"
11321132
msg, to_send = _encode_batched_write_command(
11331133
namespace, self.op_type, cmd, docs, self.codec, self
@@ -1517,7 +1517,7 @@ def unpack_response(
15171517
codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
15181518
user_fields: Optional[Mapping[str, Any]] = None,
15191519
legacy_response: bool = False,
1520-
) -> List[_DocumentOut]:
1520+
) -> List[Dict[str, Any]]:
15211521
"""Unpack a response from the database and decode the BSON document(s).
15221522
15231523
Check the response for errors and unpack, returning a dictionary
@@ -1541,7 +1541,7 @@ def unpack_response(
15411541
return bson.decode_all(self.documents, codec_options)
15421542
return bson._decode_all_selective(self.documents, codec_options, user_fields)
15431543

1544-
def command_response(self, codec_options: CodecOptions) -> Mapping[str, Any]:
1544+
def command_response(self, codec_options: CodecOptions) -> Dict[str, Any]:
15451545
"""Unpack a command response."""
15461546
docs = self.unpack_response(codec_options=codec_options)
15471547
assert self.number_returned == 1
@@ -1604,7 +1604,7 @@ def unpack_response(
16041604
codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
16051605
user_fields: Optional[Mapping[str, Any]] = None,
16061606
legacy_response: bool = False,
1607-
) -> List[_DocumentOut]:
1607+
) -> List[Dict[str, Any]]:
16081608
"""Unpack a OP_MSG command response.
16091609
16101610
:Parameters:
@@ -1619,7 +1619,7 @@ def unpack_response(
16191619
assert not legacy_response
16201620
return bson._decode_all_selective(self.payload_document, codec_options, user_fields)
16211621

1622-
def command_response(self, codec_options: CodecOptions) -> Mapping[str, Any]:
1622+
def command_response(self, codec_options: CodecOptions) -> Dict[str, Any]:
16231623
"""Unpack a command response."""
16241624
return self.unpack_response(codec_options=codec_options)[0]
16251625

pymongo/mongo_client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
import sys
117117
from types import TracebackType
118118

119+
from bson.objectid import ObjectId
119120
from pymongo.bulk import _Bulk
120121
from pymongo.client_session import ClientSession, _ServerSession
121122
from pymongo.cursor import _ConnectionManager
@@ -1898,7 +1899,9 @@ def _tmp_session(
18981899
else:
18991900
yield None
19001901

1901-
def _send_cluster_time(self, command: MutableMapping[str, Any], session: ClientSession) -> None:
1902+
def _send_cluster_time(
1903+
self, command: MutableMapping[str, Any], session: Optional[ClientSession]
1904+
) -> None:
19021905
topology_time = self._topology.max_cluster_time()
19031906
session_time = session.cluster_time if session else None
19041907
if topology_time and session_time:
@@ -2255,7 +2258,7 @@ def __init__(self, client: MongoClient, server: Server, session: Optional[Client
22552258
# of the pool at the time the connection attempt was started."
22562259
self.sock_generation = server.pool.gen.get_overall()
22572260
self.completed_handshake = False
2258-
self.service_id = None
2261+
self.service_id: Optional[ObjectId] = None
22592262
self.handled = False
22602263

22612264
def contribute_socket(self, conn: Connection, completed_handshake: bool = True) -> None:

pymongo/monitor.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from pymongo.srv_resolver import _SrvResolver
3333

3434
if TYPE_CHECKING:
35-
from pymongo.pool import Connection, Pool
35+
from pymongo.pool import Connection, Pool, _CancellationContext
3636
from pymongo.settings import TopologySettings
3737
from pymongo.topology import Topology
3838

@@ -131,9 +131,8 @@ def __init__(
131131
self._pool = pool
132132
self._settings = topology_settings
133133
self._listeners = self._settings._pool_options._event_listeners
134-
pub = self._listeners is not None
135-
self._publish = pub and self._listeners.enabled_for_server_heartbeat
136-
self._cancel_context = None
134+
self._publish = self._listeners is not None and self._listeners.enabled_for_server_heartbeat
135+
self._cancel_context: Optional[_CancellationContext] = None
137136
self._rtt_monitor = _RttMonitor(
138137
topology,
139138
topology_settings,
@@ -238,7 +237,8 @@ def _check_server(self) -> ServerDescription:
238237
address = sd.address
239238
duration = time.monotonic() - start
240239
if self._publish:
241-
awaited = sd.is_server_type_known and sd.topology_version
240+
awaited = bool(sd.is_server_type_known and sd.topology_version)
241+
assert self._listeners is not None
242242
self._listeners.publish_server_heartbeat_failed(address, duration, error, awaited)
243243
self._reset_connection()
244244
if isinstance(error, _OperationCancelled):
@@ -254,6 +254,7 @@ def _check_once(self) -> ServerDescription:
254254
"""
255255
address = self._server_description.address
256256
if self._publish:
257+
assert self._listeners is not None
257258
self._listeners.publish_server_heartbeat_started(address)
258259

259260
if self._cancel_context and self._cancel_context.cancelled:
@@ -267,6 +268,7 @@ def _check_once(self) -> ServerDescription:
267268
avg_rtt, min_rtt = self._rtt_monitor.get()
268269
sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt)
269270
if self._publish:
271+
assert self._listeners is not None
270272
self._listeners.publish_server_heartbeat_succeeded(
271273
address, round_trip_time, response, response.awaitable
272274
)

pymongo/network.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,13 @@
4747
if TYPE_CHECKING:
4848
from bson import CodecOptions
4949
from pymongo.client_session import ClientSession
50-
from pymongo.collation import Collation
5150
from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
5251
from pymongo.mongo_client import MongoClient
5352
from pymongo.monitoring import _EventListeners
5453
from pymongo.pool import Connection
5554
from pymongo.read_concern import ReadConcern
5655
from pymongo.read_preferences import _ServerMode
57-
from pymongo.typings import _Address, _DocumentOut, _DocumentType
56+
from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
5857
from pymongo.write_concern import WriteConcern
5958

6059
_UNPACK_HEADER = struct.Struct("<iiii").unpack
@@ -65,7 +64,7 @@ def command(
6564
dbname: str,
6665
spec: MutableMapping[str, Any],
6766
is_mongos: bool,
68-
read_preference: _ServerMode,
67+
read_preference: Optional[_ServerMode],
6968
codec_options: CodecOptions[_DocumentType],
7069
session: Optional[ClientSession],
7170
client: Optional[MongoClient],
@@ -76,7 +75,7 @@ def command(
7675
max_bson_size: Optional[int] = None,
7776
read_concern: Optional[ReadConcern] = None,
7877
parse_write_concern_error: bool = False,
79-
collation: Optional[Collation] = None,
78+
collation: Optional[_CollationIn] = None,
8079
compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
8180
use_op_msg: bool = False,
8281
unacknowledged: bool = False,
@@ -119,6 +118,7 @@ def command(
119118
# Publish the original command document, perhaps with lsid and $clusterTime.
120119
orig = spec
121120
if is_mongos and not use_op_msg:
121+
assert read_preference is not None
122122
spec = message._maybe_add_read_preference(spec, read_preference)
123123
if read_concern and not (session and session.in_transaction):
124124
if read_concern.level:
@@ -232,7 +232,7 @@ def command(
232232

233233

234234
def receive_message(
235-
conn: Connection, request_id: int, max_message_size: int = MAX_MESSAGE_SIZE
235+
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
236236
) -> Union[_OpReply, _OpMsg]:
237237
"""Receive a raw BSON message or raise socket.error."""
238238
if _csot.get_timeout():

0 commit comments

Comments
 (0)