Skip to content

PYTHON-4796 Update type checkers and handle with_options typing #1880

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bson/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,7 +1324,7 @@ def decode_iter(
elements = data[position : position + obj_size]
position += obj_size

yield _bson_to_dict(elements, opts) # type:ignore[misc, type-var]
yield _bson_to_dict(elements, opts) # type:ignore[misc]


@overload
Expand Down Expand Up @@ -1370,7 +1370,7 @@ def decode_file_iter(
raise InvalidBSON("cut off in middle of objsize")
obj_size = _UNPACK_INT_FROM(size_data, 0)[0] - 4
elements = size_data + file_obj.read(max(0, obj_size))
yield _bson_to_dict(elements, opts) # type:ignore[type-var, arg-type, misc]
yield _bson_to_dict(elements, opts) # type:ignore[arg-type, misc]


def is_valid(bson: bytes) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion bson/decimal128.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def __init__(self, value: _VALUE_OPTIONS) -> None:
"from list or tuple. Must have exactly 2 "
"elements."
)
self.__high, self.__low = value # type: ignore
self.__high, self.__low = value
else:
raise TypeError(f"Cannot convert {value!r} to Decimal128")

Expand Down
2 changes: 1 addition & 1 deletion bson/json_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def __new__(
"JSONOptions.datetime_representation must be one of LEGACY, "
"NUMBERLONG, or ISO8601 from DatetimeRepresentation."
)
self = cast(JSONOptions, super().__new__(cls, *args, **kwargs)) # type:ignore[arg-type]
self = cast(JSONOptions, super().__new__(cls, *args, **kwargs))
if json_mode not in (JSONMode.LEGACY, JSONMode.RELAXED, JSONMode.CANONICAL):
raise ValueError(
"JSONOptions.json_mode must be one of LEGACY, RELAXED, "
Expand Down
2 changes: 1 addition & 1 deletion bson/son.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
self.update(kwargs)

def __new__(cls: Type[SON[_Key, _Value]], *args: Any, **kwargs: Any) -> SON[_Key, _Value]:
instance = super().__new__(cls, *args, **kwargs) # type: ignore[type-var]
instance = super().__new__(cls, *args, **kwargs)
instance.__keys = []
return instance

Expand Down
5 changes: 3 additions & 2 deletions hatch.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ features = ["docs","test"]
test = "sphinx-build -E -b doctest doc ./doc/_build/doctest"

[envs.typing]
features = ["encryption", "ocsp", "zstd", "aws"]
dependencies = ["mypy==1.2.0","pyright==1.1.290", "certifi", "typing_extensions"]
pre-install-commands = [
"pip install -q -r requirements/typing.txt",
]
[envs.typing.scripts]
check-mypy = [
"mypy --install-types --non-interactive bson gridfs tools pymongo",
Expand Down
3 changes: 1 addition & 2 deletions pymongo/_csot.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,13 @@ def __init__(self, timeout: Optional[float]):
self._timeout = timeout
self._tokens: Optional[tuple[Token[Optional[float]], Token[float], Token[float]]] = None

def __enter__(self) -> _TimeoutContext:
def __enter__(self) -> None:
timeout_token = TIMEOUT.set(self._timeout)
prev_deadline = DEADLINE.get()
next_deadline = time.monotonic() + self._timeout if self._timeout else float("inf")
deadline_token = DEADLINE.set(min(prev_deadline, next_deadline))
rtt_token = RTT.set(0.0)
self._tokens = (timeout_token, deadline_token, rtt_token)
return self

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._tokens:
Expand Down
23 changes: 22 additions & 1 deletion pymongo/asynchronous/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
TypeVar,
Union,
cast,
overload,
)

from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions
Expand Down Expand Up @@ -332,13 +333,33 @@ def database(self) -> AsyncDatabase[_DocumentType]:
"""
return self._database

@overload
def with_options(
self,
codec_options: None = None,
read_preference: Optional[_ServerMode] = ...,
write_concern: Optional[WriteConcern] = ...,
read_concern: Optional[ReadConcern] = ...,
) -> AsyncCollection[_DocumentType]:
...

@overload
def with_options(
self,
codec_options: bson.CodecOptions[_DocumentTypeArg],
read_preference: Optional[_ServerMode] = ...,
write_concern: Optional[WriteConcern] = ...,
read_concern: Optional[ReadConcern] = ...,
) -> AsyncCollection[_DocumentTypeArg]:
...

def with_options(
self,
codec_options: Optional[bson.CodecOptions[_DocumentTypeArg]] = None,
read_preference: Optional[_ServerMode] = None,
write_concern: Optional[WriteConcern] = None,
read_concern: Optional[ReadConcern] = None,
) -> AsyncCollection[_DocumentType]:
) -> AsyncCollection[_DocumentType] | AsyncCollection[_DocumentTypeArg]:
"""Get a clone of this collection changing the specified settings.

>>> coll1.read_preference
Expand Down
22 changes: 21 additions & 1 deletion pymongo/asynchronous/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,33 @@ def name(self) -> str:
"""The name of this :class:`AsyncDatabase`."""
return self._name

@overload
def with_options(
self,
codec_options: None = None,
read_preference: Optional[_ServerMode] = ...,
write_concern: Optional[WriteConcern] = ...,
read_concern: Optional[ReadConcern] = ...,
) -> AsyncDatabase[_DocumentType]:
...

@overload
def with_options(
self,
codec_options: bson.CodecOptions[_DocumentTypeArg],
read_preference: Optional[_ServerMode] = ...,
write_concern: Optional[WriteConcern] = ...,
read_concern: Optional[ReadConcern] = ...,
) -> AsyncDatabase[_DocumentTypeArg]:
...

def with_options(
self,
codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None,
read_preference: Optional[_ServerMode] = None,
write_concern: Optional[WriteConcern] = None,
read_concern: Optional[ReadConcern] = None,
) -> AsyncDatabase[_DocumentType]:
) -> AsyncDatabase[_DocumentType] | AsyncDatabase[_DocumentTypeArg]:
"""Get a clone of this database changing the specified settings.

>>> db1.read_preference
Expand Down
2 changes: 1 addition & 1 deletion pymongo/asynchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ async def _configured_socket(
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host)
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined]
except _CertificateError:
ssl_sock.close()
raise
Expand Down
2 changes: 1 addition & 1 deletion pymongo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ def get_normed_key(x: str) -> str:
return x

def get_setter_key(x: str) -> str:
return options.cased_key(x) # type: ignore[attr-defined]
return options.cased_key(x)

else:
validated_options = {}
Expand Down
2 changes: 1 addition & 1 deletion pymongo/compression_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

def _have_snappy() -> bool:
try:
import snappy # type:ignore[import] # noqa: F401
import snappy # type:ignore[import-not-found] # noqa: F401

return True
except ImportError:
Expand Down
2 changes: 1 addition & 1 deletion pymongo/encryption_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import TYPE_CHECKING, Any, Mapping, Optional

try:
import pymongocrypt # type:ignore[import] # noqa: F401
import pymongocrypt # type:ignore[import-untyped] # noqa: F401

# Check for pymongocrypt>=1.10.
from pymongocrypt import synchronous as _ # noqa: F401
Expand Down
23 changes: 22 additions & 1 deletion pymongo/synchronous/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
TypeVar,
Union,
cast,
overload,
)

from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions
Expand Down Expand Up @@ -333,13 +334,33 @@ def database(self) -> Database[_DocumentType]:
"""
return self._database

@overload
def with_options(
self,
codec_options: None = None,
read_preference: Optional[_ServerMode] = ...,
write_concern: Optional[WriteConcern] = ...,
read_concern: Optional[ReadConcern] = ...,
) -> Collection[_DocumentType]:
...

@overload
def with_options(
self,
codec_options: bson.CodecOptions[_DocumentTypeArg],
read_preference: Optional[_ServerMode] = ...,
write_concern: Optional[WriteConcern] = ...,
read_concern: Optional[ReadConcern] = ...,
) -> Collection[_DocumentTypeArg]:
...

def with_options(
self,
codec_options: Optional[bson.CodecOptions[_DocumentTypeArg]] = None,
read_preference: Optional[_ServerMode] = None,
write_concern: Optional[WriteConcern] = None,
read_concern: Optional[ReadConcern] = None,
) -> Collection[_DocumentType]:
) -> Collection[_DocumentType] | Collection[_DocumentTypeArg]:
"""Get a clone of this collection changing the specified settings.

>>> coll1.read_preference
Expand Down
22 changes: 21 additions & 1 deletion pymongo/synchronous/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,33 @@ def name(self) -> str:
"""The name of this :class:`Database`."""
return self._name

@overload
def with_options(
self,
codec_options: None = None,
read_preference: Optional[_ServerMode] = ...,
write_concern: Optional[WriteConcern] = ...,
read_concern: Optional[ReadConcern] = ...,
) -> Database[_DocumentType]:
...

@overload
def with_options(
self,
codec_options: bson.CodecOptions[_DocumentTypeArg],
read_preference: Optional[_ServerMode] = ...,
write_concern: Optional[WriteConcern] = ...,
read_concern: Optional[ReadConcern] = ...,
) -> Database[_DocumentTypeArg]:
...

def with_options(
self,
codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None,
read_preference: Optional[_ServerMode] = None,
write_concern: Optional[WriteConcern] = None,
read_concern: Optional[ReadConcern] = None,
) -> Database[_DocumentType]:
) -> Database[_DocumentType] | Database[_DocumentTypeArg]:
"""Get a clone of this database changing the specified settings.

>>> db1.read_preference
Expand Down
2 changes: 1 addition & 1 deletion pymongo/synchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,7 @@ def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.
and not options.tls_allow_invalid_hostnames
):
try:
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host)
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined]
except _CertificateError:
ssl_sock.close()
raise
Expand Down
7 changes: 7 additions & 0 deletions requirements/typing.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mypy==1.11.2
pyright==1.1.382.post1
typing_extensions
-r ./encryption.txt
-r ./ocsp.txt
-r ./zstd.txt
-r ./aws.txt
2 changes: 1 addition & 1 deletion test/asynchronous/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ def test_with_options(self):
"write_concern": WriteConcern(w=1),
"read_concern": ReadConcern(level="local"),
}
db2 = db1.with_options(**newopts) # type: ignore[arg-type]
db2 = db1.with_options(**newopts) # type: ignore[arg-type, call-overload]
for opt in newopts:
self.assertEqual(getattr(db2, opt), newopts.get(opt, getattr(db1, opt)))

Expand Down
2 changes: 1 addition & 1 deletion test/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ def test_with_options(self):
"write_concern": WriteConcern(w=1),
"read_concern": ReadConcern(level="local"),
}
db2 = db1.with_options(**newopts) # type: ignore[arg-type]
db2 = db1.with_options(**newopts) # type: ignore[arg-type, call-overload]
for opt in newopts:
self.assertEqual(getattr(db2, opt), newopts.get(opt, getattr(db1, opt)))

Expand Down
34 changes: 22 additions & 12 deletions test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
cast,
)

try:
if TYPE_CHECKING:
from typing_extensions import NotRequired, TypedDict

from bson import ObjectId
Expand All @@ -49,16 +49,13 @@ class MovieWithId(TypedDict):
year: int

class ImplicitMovie(TypedDict):
_id: NotRequired[ObjectId] # pyright: ignore[reportGeneralTypeIssues]
_id: NotRequired[ObjectId]
name: str
year: int

except ImportError:
Movie = dict # type:ignore[misc,assignment]
ImplicitMovie = dict # type: ignore[assignment,misc]
MovieWithId = dict # type: ignore[assignment,misc]
TypedDict = None
NotRequired = None # type: ignore[assignment]
else:
Movie = dict
ImplicitMovie = dict
NotRequired = None


try:
Expand Down Expand Up @@ -234,6 +231,19 @@ def execute_transaction(session):
execute_transaction, read_preference=ReadPreference.PRIMARY
)

def test_with_options(self) -> None:
coll: Collection[Dict[str, Any]] = self.coll
coll.drop()
doc = {"name": "foo", "year": 1982, "other": 1}
coll.insert_one(doc)

coll2 = coll.with_options(codec_options=CodecOptions(document_class=Movie))
retrieved = coll2.find_one()
assert retrieved is not None
assert retrieved["name"] == "foo"
# We expect a type error here.
assert retrieved["other"] == 1 # type:ignore[typeddict-item]


class TestDecode(unittest.TestCase):
def test_bson_decode(self) -> None:
Expand Down Expand Up @@ -426,7 +436,7 @@ def test_bulk_write_document_type_insertion(self):
)
coll.bulk_write(
[
InsertOne({"_id": ObjectId(), "name": "THX-1138", "year": 1971})
InsertOne({"_id": ObjectId(), "name": "THX-1138", "year": 1971}) # pyright: ignore
] # No error because it is in-line.
)

Expand All @@ -443,7 +453,7 @@ def test_bulk_write_document_type_replacement(self):
)
coll.bulk_write(
[
ReplaceOne({}, {"_id": ObjectId(), "name": "THX-1138", "year": 1971})
ReplaceOne({}, {"_id": ObjectId(), "name": "THX-1138", "year": 1971}) # pyright: ignore
] # No error because it is in-line.
)

Expand Down Expand Up @@ -566,7 +576,7 @@ def test_explicit_document_type(self) -> None:
def test_typeddict_document_type(self) -> None:
options: CodecOptions[Movie] = CodecOptions()
# Suppress: Cannot instantiate type "Type[Movie]".
obj = options.document_class(name="a", year=1) # type: ignore[misc]
obj = options.document_class(name="a", year=1)
assert obj["year"] == 1
assert obj["name"] == "a"

Expand Down
2 changes: 1 addition & 1 deletion tools/synchro.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from os import listdir
from pathlib import Path

from unasync import Rule, unasync_files # type: ignore[import]
from unasync import Rule, unasync_files # type: ignore[import-not-found]

replacements = {
"AsyncCollection": "Collection",
Expand Down
Loading