diff --git a/bson/_cbsonmodule.c b/bson/_cbsonmodule.c index d91c7e0536..672f5eeda5 100644 --- a/bson/_cbsonmodule.c +++ b/bson/_cbsonmodule.c @@ -1644,6 +1644,56 @@ static int write_raw_doc(buffer_t buffer, PyObject* raw, PyObject* _raw_str) { return bytes_written; } + +/* Update Invalid Document error message to include doc. + */ +void handle_invalid_doc_error(PyObject* dict) { + PyObject *etype = NULL, *evalue = NULL, *etrace = NULL; + PyObject *msg = NULL, *dict_str = NULL, *new_msg = NULL; + PyErr_Fetch(&etype, &evalue, &etrace); + PyObject *InvalidDocument = _error("InvalidDocument"); + if (InvalidDocument == NULL) { + goto cleanup; + } + + if (evalue && PyErr_GivenExceptionMatches(etype, InvalidDocument)) { + PyObject *msg = PyObject_Str(evalue); + if (msg) { + // Prepend doc to the existing message + PyObject *dict_str = PyObject_Str(dict); + if (dict_str == NULL) { + goto cleanup; + } + const char * dict_str_utf8 = PyUnicode_AsUTF8(dict_str); + if (dict_str_utf8 == NULL) { + goto cleanup; + } + const char * msg_utf8 = PyUnicode_AsUTF8(msg); + if (msg_utf8 == NULL) { + goto cleanup; + } + PyObject *new_msg = PyUnicode_FromFormat("Invalid document %s | %s", dict_str_utf8, msg_utf8); + Py_DECREF(evalue); + Py_DECREF(etype); + etype = InvalidDocument; + InvalidDocument = NULL; + if (new_msg) { + evalue = new_msg; + } else { + evalue = msg; + } + } + PyErr_NormalizeException(&etype, &evalue, &etrace); + } +cleanup: + PyErr_Restore(etype, evalue, etrace); + Py_XDECREF(msg); + Py_XDECREF(InvalidDocument); + Py_XDECREF(dict_str); + Py_XDECREF(new_msg); +} + + /* returns the number of bytes written or 0 on failure */ int write_dict(PyObject* self, buffer_t buffer, PyObject* dict, unsigned char check_keys, @@ -1743,40 +1793,8 @@ int write_dict(PyObject* self, buffer_t buffer, while (PyDict_Next(dict, &pos, &key, &value)) { if (!decode_and_write_pair(self, buffer, key, value, check_keys, options, top_level)) { - if (PyErr_Occurred()) { - PyObject *etype = NULL, *evalue = NULL, *etrace = NULL; - PyErr_Fetch(&etype, &evalue, &etrace); - PyObject *InvalidDocument = _error("InvalidDocument"); - - if (top_level && InvalidDocument && PyErr_GivenExceptionMatches(etype, InvalidDocument)) { - - Py_DECREF(etype); - etype = InvalidDocument; - - if (evalue) { - PyObject *msg = PyObject_Str(evalue); - Py_DECREF(evalue); - - if (msg) { - // Prepend doc to the existing message - PyObject *dict_str = PyObject_Str(dict); - PyObject *new_msg = PyUnicode_FromFormat("Invalid document %s | %s", PyUnicode_AsUTF8(dict_str), PyUnicode_AsUTF8(msg)); - Py_DECREF(dict_str); - - if (new_msg) { - evalue = new_msg; - } - else { - evalue = msg; - } - } - } - PyErr_NormalizeException(&etype, &evalue, &etrace); - } - else { - Py_DECREF(InvalidDocument); - } - PyErr_Restore(etype, evalue, etrace); + if (PyErr_Occurred() && top_level) { + handle_invalid_doc_error(dict); } return 0; } @@ -1796,6 +1814,9 @@ int write_dict(PyObject* self, buffer_t buffer, } if (!decode_and_write_pair(self, buffer, key, value, check_keys, options, top_level)) { + if (PyErr_Occurred() && top_level) { + handle_invalid_doc_error(dict); + } Py_DECREF(key); Py_DECREF(value); Py_DECREF(iter); diff --git a/test/test_bson.py b/test/test_bson.py index e550b538d3..e601be4915 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -1112,6 +1112,34 @@ def __repr__(self): with self.assertRaisesRegex(InvalidDocument, f"Invalid document {doc}"): encode(doc) + def test_doc_in_invalid_document_error_message_mapping(self): + class MyMapping(abc.Mapping): + def keys(): + return ["t"] + + def __getitem__(self, name): + if name == "_id": + return None + return Wrapper(name) + + def __len__(self): + return 1 + + def __iter__(self): + return iter(["t"]) + + class Wrapper: + def __init__(self, val): + self.val = val + + def __repr__(self): + return repr(self.val) + + self.assertEqual("1", repr(Wrapper(1))) + doc = MyMapping() + with self.assertRaisesRegex(InvalidDocument, f"Invalid document {doc}"): + encode(doc) + class TestCodecOptions(unittest.TestCase): def test_document_class(self):