diff --git a/bson/__init__.py b/bson/__init__.py index 972b184015..f61a7ddf75 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -1406,7 +1406,7 @@ def has_c() -> bool: return _USE_C -def _after_fork(): +def _after_fork() -> None: """Releases the ObjectID lock child.""" if ObjectId._inc_lock.locked(): ObjectId._inc_lock.release() diff --git a/bson/binary.py b/bson/binary.py index 0727bd17c0..f8a475f8a3 100644 --- a/bson/binary.py +++ b/bson/binary.py @@ -365,5 +365,5 @@ def __hash__(self) -> int: def __ne__(self, other: Any) -> bool: return not self == other - def __repr__(self): + def __repr__(self) -> str: return f"Binary({bytes.__repr__(self)}, {self.__subtype})" diff --git a/bson/code.py b/bson/code.py index 37b2aa85d4..26bed0103d 100644 --- a/bson/code.py +++ b/bson/code.py @@ -86,7 +86,7 @@ def scope(self) -> Optional[Mapping[str, Any]]: """Scope dictionary for this instance or ``None``.""" return self.__scope - def __repr__(self): + def __repr__(self) -> str: return f"Code({str.__repr__(self)}, {self.__scope!r})" def __eq__(self, other: Any) -> bool: diff --git a/bson/codec_options.py b/bson/codec_options.py index f146898dfd..9c511b5d6f 100644 --- a/bson/codec_options.py +++ b/bson/codec_options.py @@ -172,7 +172,7 @@ def _validate_type_encoder(self, codec: _Codec) -> None: ) raise TypeError(err_msg) - def __repr__(self): + def __repr__(self) -> str: return "{}(type_codecs={!r}, fallback_encoder={!r})".format( self.__class__.__name__, self.__type_codecs, @@ -465,7 +465,7 @@ def _options_dict(self) -> Dict[str, Any]: "datetime_conversion": self.datetime_conversion, } - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self._arguments_repr()})" def with_options(self, **kwargs: Any) -> "CodecOptions": diff --git a/bson/datetime_ms.py b/bson/datetime_ms.py index c422d6e379..6b9472b5b6 100644 --- a/bson/datetime_ms.py +++ b/bson/datetime_ms.py @@ -111,12 +111,12 @@ def __int__(self) -> int: # Timezones are hashed by their offset, which is a timedelta # and therefore there are more than 24 possible timezones. @functools.lru_cache(maxsize=None) -def _min_datetime_ms(tz=datetime.timezone.utc): +def _min_datetime_ms(tz: datetime.timezone = datetime.timezone.utc) -> int: return _datetime_to_millis(datetime.datetime.min.replace(tzinfo=tz)) @functools.lru_cache(maxsize=None) -def _max_datetime_ms(tz=datetime.timezone.utc): +def _max_datetime_ms(tz: datetime.timezone = datetime.timezone.utc) -> int: return _datetime_to_millis(datetime.datetime.max.replace(tzinfo=tz)) diff --git a/bson/dbref.py b/bson/dbref.py index 1141e7ba12..1bd4cadcc0 100644 --- a/bson/dbref.py +++ b/bson/dbref.py @@ -101,7 +101,7 @@ def as_doc(self) -> SON[str, Any]: doc.update(self.__kwargs) return doc - def __repr__(self): + def __repr__(self) -> str: extra = "".join([f", {k}={v!r}" for k, v in self.__kwargs.items()]) if self.database is None: return f"DBRef({self.collection!r}, {self.id!r}{extra})" diff --git a/bson/decimal128.py b/bson/decimal128.py index 0e24b5bbae..fd39e94705 100644 --- a/bson/decimal128.py +++ b/bson/decimal128.py @@ -296,7 +296,7 @@ def __str__(self) -> str: return "NaN" return str(dec) - def __repr__(self): + def __repr__(self) -> str: return f"Decimal128('{str(self)}')" def __setstate__(self, value: Tuple[int, int]) -> None: diff --git a/bson/json_util.py b/bson/json_util.py index bc566fa982..82604f382f 100644 --- a/bson/json_util.py +++ b/bson/json_util.py @@ -224,7 +224,7 @@ class JSONOptions(CodecOptions): datetime_representation: int strict_uuid: bool - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): """Encapsulates JSON options for :func:`dumps` and :func:`loads`. :Parameters: diff --git a/bson/max_key.py b/bson/max_key.py index eb5705d378..83278087b0 100644 --- a/bson/max_key.py +++ b/bson/max_key.py @@ -50,5 +50,5 @@ def __ge__(self, dummy: Any) -> bool: def __gt__(self, other: Any) -> bool: return not isinstance(other, MaxKey) - def __repr__(self): + def __repr__(self) -> str: return "MaxKey()" diff --git a/bson/min_key.py b/bson/min_key.py index 2c8f73d560..50011df6e5 100644 --- a/bson/min_key.py +++ b/bson/min_key.py @@ -50,5 +50,5 @@ def __ge__(self, other: Any) -> bool: def __gt__(self, dummy: Any) -> bool: return False - def __repr__(self): + def __repr__(self) -> str: return "MinKey()" diff --git a/bson/objectid.py b/bson/objectid.py index 966fd9f94b..d3afe3cd3c 100644 --- a/bson/objectid.py +++ b/bson/objectid.py @@ -243,7 +243,7 @@ def __setstate__(self, value: Any) -> None: def __str__(self) -> str: return binascii.hexlify(self.__id).decode() - def __repr__(self): + def __repr__(self) -> str: return f"ObjectId('{str(self)}')" def __eq__(self, other: Any) -> bool: diff --git a/bson/raw_bson.py b/bson/raw_bson.py index f48016909c..d5dbe8fbf9 100644 --- a/bson/raw_bson.py +++ b/bson/raw_bson.py @@ -173,7 +173,7 @@ def __eq__(self, other: Any) -> bool: return self.__raw == other.raw return NotImplemented - def __repr__(self): + def __repr__(self) -> str: return "{}({!r}, codec_options={!r})".format( self.__class__.__name__, self.raw, diff --git a/bson/regex.py b/bson/regex.py index c06e493f38..fe852fdfce 100644 --- a/bson/regex.py +++ b/bson/regex.py @@ -115,7 +115,7 @@ def __eq__(self, other: Any) -> bool: def __ne__(self, other: Any) -> bool: return not self == other - def __repr__(self): + def __repr__(self) -> str: return f"Regex({self.pattern!r}, {self.flags!r})" def try_compile(self) -> "Pattern[_T]": diff --git a/bson/son.py b/bson/son.py index 482e8d2584..7be749ceca 100644 --- a/bson/son.py +++ b/bson/son.py @@ -71,7 +71,7 @@ def __new__(cls: Type["SON[_Key, _Value]"], *args: Any, **kwargs: Any) -> "SON[_ instance.__keys = [] return instance - def __repr__(self): + def __repr__(self) -> str: result = [] for key in self.__keys: result.append(f"({key!r}, {self[key]!r})") diff --git a/bson/timestamp.py b/bson/timestamp.py index 5591b60e41..168f2824df 100644 --- a/bson/timestamp.py +++ b/bson/timestamp.py @@ -111,7 +111,7 @@ def __ge__(self, other: Any) -> bool: return (self.time, self.inc) >= (other.time, other.inc) return NotImplemented - def __repr__(self): + def __repr__(self) -> str: return f"Timestamp({self.__time}, {self.__inc})" def as_datetime(self) -> datetime.datetime: diff --git a/pymongo/_csot.py b/pymongo/_csot.py index 7a5a8a7302..9ad477a249 100644 --- a/pymongo/_csot.py +++ b/pymongo/_csot.py @@ -14,6 +14,8 @@ """Internal helpers for CSOT.""" +from __future__ import annotations + import functools import time from collections import deque @@ -72,7 +74,7 @@ def __init__(self, timeout: Optional[float]): self._timeout = timeout self._tokens: Optional[Tuple[Token, Token, Token]] = None - def __enter__(self): + def __enter__(self) -> _TimeoutContext: timeout_token = TIMEOUT.set(self._timeout) prev_deadline = DEADLINE.get() next_deadline = time.monotonic() + self._timeout if self._timeout else float("inf") @@ -81,7 +83,7 @@ def __enter__(self): self._tokens = (timeout_token, deadline_token, rtt_token) return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: if self._tokens: timeout_token, deadline_token, rtt_token = self._tokens TIMEOUT.reset(timeout_token) @@ -97,7 +99,7 @@ def apply(func: F) -> F: """Apply the client's timeoutMS to this operation.""" @functools.wraps(func) - def csot_wrapper(self, *args, **kwargs): + def csot_wrapper(self: Any, *args: Any, **kwargs: Any) -> F: if get_timeout() is None: timeout = self._timeout if timeout is not None: diff --git a/pymongo/auth_aws.py b/pymongo/auth_aws.py index 2f7dcb857f..a327016d73 100644 --- a/pymongo/auth_aws.py +++ b/pymongo/auth_aws.py @@ -23,7 +23,7 @@ except ImportError: class AwsSaslContext: # type: ignore - def __init__(self, credentials): + def __init__(self, credentials: MongoCredential): pass _HAVE_MONGODB_AWS = False @@ -35,11 +35,11 @@ def __init__(self, credentials): set_use_cached_credentials(True) except ImportError: - def set_cached_credentials(creds): + def set_cached_credentials(creds: Optional[AwsCredential]) -> None: pass -from typing import TYPE_CHECKING, Any, Mapping +from typing import TYPE_CHECKING, Any, Mapping, Optional, Type import bson from bson.binary import Binary @@ -54,7 +54,7 @@ def set_cached_credentials(creds): class _AwsSaslContext(AwsSaslContext): # type: ignore # Dependency injection: - def binary_type(self): + def binary_type(self) -> Type[Binary]: """Return the bson.binary.Binary type.""" return Binary diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index a40b5b2f14..ed3031ef52 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -257,7 +257,7 @@ def _run_aggregation_cmd( cmd.get_cursor, self._target._read_preference_for(session), session ) - def _create_cursor(self): + def _create_cursor(self) -> CommandCursor: with self._client._tmp_session(self._session, close=False) as s: return self._run_aggregation_cmd(session=s, explicit_session=self._session is not None) diff --git a/pymongo/collation.py b/pymongo/collation.py index bdc996be1b..bada2d9417 100644 --- a/pymongo/collation.py +++ b/pymongo/collation.py @@ -199,7 +199,7 @@ def document(self) -> Dict[str, Any]: """ return self.__document.copy() - def __repr__(self): + def __repr__(self) -> str: document = self.document return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document)) diff --git a/pymongo/collection.py b/pymongo/collection.py index fbbe7fb593..6b2b16db77 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -2481,7 +2481,7 @@ def create_search_indexes( if comment is not None: kwargs["comment"] = comment - def gen_indexes(): + def gen_indexes() -> Iterator[Mapping[str, Any]]: for index in models: if not isinstance(index, SearchIndexModel): raise TypeError( diff --git a/pymongo/event_loggers.py b/pymongo/event_loggers.py index 248dfb17bd..70c386ab04 100644 --- a/pymongo/event_loggers.py +++ b/pymongo/event_loggers.py @@ -179,7 +179,7 @@ class ConnectionPoolLogger(monitoring.ConnectionPoolListener): def pool_created(self, event: monitoring.PoolCreatedEvent) -> None: logging.info(f"[pool {event.address}] pool created") - def pool_ready(self, event): + def pool_ready(self, event: monitoring.PoolReadyEvent) -> None: logging.info(f"[pool {event.address}] pool ready") def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None: diff --git a/pymongo/ocsp_cache.py b/pymongo/ocsp_cache.py index b0ac4d654f..033a7b607a 100644 --- a/pymongo/ocsp_cache.py +++ b/pymongo/ocsp_cache.py @@ -19,7 +19,7 @@ from collections import namedtuple from datetime import datetime as _datetime from datetime import timezone -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict from pymongo.lock import _create_lock @@ -35,8 +35,8 @@ class _OCSPCache: ["hash_algorithm", "issuer_name_hash", "issuer_key_hash", "serial_number"], ) - def __init__(self): - self._data = {} + def __init__(self) -> None: + self._data: Dict[Any, OCSPResponse] = {} # Hold this lock when accessing _data. self._lock = _create_lock() @@ -77,7 +77,10 @@ def __setitem__(self, key: OCSPRequest, value: OCSPResponse) -> None: # Cache new response OR update cached response if new response # has longer validity. cached_value = self._data.get(cache_key, None) - if cached_value is None or cached_value.next_update < value.next_update: + if cached_value is None or ( + cached_value.next_update is not None + and cached_value.next_update < value.next_update + ): self._data[cache_key] = value def __getitem__(self, item: OCSPRequest) -> OCSPResponse: @@ -92,6 +95,8 @@ def __getitem__(self, item: OCSPRequest) -> OCSPResponse: value = self._data[cache_key] # Return cached response if it is still valid. + assert value.this_update is not None + assert value.next_update is not None if ( value.this_update <= _datetime.now(tz=timezone.utc).replace(tzinfo=None) diff --git a/pymongo/topology.py b/pymongo/topology.py index 6fd1138fb2..5b4197bc16 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -890,7 +890,7 @@ def __repr__(self) -> str: msg = "CLOSED " return f"<{self.__class__.__name__} {msg}{self._description!r}>" - def eq_props(self): + def eq_props(self) -> Tuple[Tuple[_Address, ...], Optional[str], Optional[str], str]: """The properties to use for MongoClient/Topology equality checks.""" ts = self._settings return (tuple(sorted(ts.seeds)), ts.replica_set_name, ts.fqdn, ts.srv_service_name) diff --git a/tools/ocsptest.py b/tools/ocsptest.py index bba84252be..702f15ee99 100644 --- a/tools/ocsptest.py +++ b/tools/ocsptest.py @@ -25,7 +25,7 @@ logging.basicConfig(format=FORMAT, level=logging.DEBUG) -def check_ocsp(host, port, capath): +def check_ocsp(host: str, port: int, capath: str) -> None: ctx = get_ssl_context( None, # certfile None, # passphrase @@ -47,7 +47,7 @@ def check_ocsp(host, port, capath): s.close() -def main(): +def main() -> None: parser = argparse.ArgumentParser(description="Debug OCSP") parser.add_argument("--host", type=str, required=True, help="Host to connect to") parser.add_argument("-p", "--port", type=int, default=443, help="Port to connect to") diff --git a/tox.ini b/tox.ini index c268fd1f4c..11913a8832 100644 --- a/tox.ini +++ b/tox.ini @@ -86,7 +86,7 @@ deps = certifi; platform_system == "win32" or platform_system == "Darwin" typing_extensions commands = - mypy --install-types --non-interactive bson gridfs tools pymongo + mypy --install-types --non-interactive --disallow-untyped-defs bson gridfs tools pymongo mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment --disable-error-code no-redef --disable-error-code index --allow-redefinition --allow-untyped-globals --exclude "test/mypy_fails/*.*" --exclude "test/conftest.py" test mypy --install-types --non-interactive test/test_typing.py test/test_typing_strict.py