diff --git a/DEVELOPER.md b/DEVELOPER.md index 8a50c7bf9..e886fa63a 100644 --- a/DEVELOPER.md +++ b/DEVELOPER.md @@ -31,6 +31,15 @@ or: Documentation will be generated in `build/sphinx/html`. +## Unasync -- maintaining sync versions of async code + + $ python tools/unasync.py + + # Run the script with the --check flag to ensure the sync code is up to date + $ python tools/unasync.py --check + +If you make any changes to the async code (in `src/confluent_kafka/schema_registry/_async` and `tests/integration/schema_registry/_async`), you **must** run this script to generate the sync counterparts (in `src/confluent_kafka/schema_registry/_sync` and `tests/integration/schema_registry/_sync`). This script is also run in CI with the --check flag and will fail the build if the generated sync code is out of date. + ## Tests diff --git a/LICENSE b/LICENSE index 521517282..02f2aa004 100644 --- a/LICENSE +++ b/LICENSE @@ -652,3 +652,36 @@ For the files wingetopt.c wingetopt.h downloaded from https://github.com/alex85k */ + +LICENSE.unasync +-------------------------------------------------------------- +For unasync code in setup.py, derived from +https://github.com/encode/httpcore/blob/ae46dfbd4330eefaa9cd6ab1560dec18a1d0bcb8/scripts/unasync.py + +Copyright © 2020, [Encode OSS Ltd](https://www.encode.io/). +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
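For contributors unfamiliar with the httpcore-style unasync approach credited above, the idea is purely mechanical token substitution: the script walks the `_async` trees and rewrites each module into its `_sync` twin. The sketch below is only an illustration of that technique under assumed substitution rules; it is not the repository's actual `tools/unasync.py`, which defines its own rule table and file mapping.

```python
import re
from pathlib import Path

# Illustrative substitution rules only; the real tools/unasync.py has its own.
SUBS = [
    (r"async def ", "def "),
    (r"await ", ""),
    (r"async with ", "with "),
    (r"async for ", "for "),
    (r"AsyncSchemaRegistryClient", "SchemaRegistryClient"),
    (r"_async", "_sync"),
]
COMPILED = [(re.compile(pattern), repl) for pattern, repl in SUBS]


def unasync_file(src: Path, dst: Path) -> None:
    """Rewrite one async module into its sync counterpart."""
    text = src.read_text()
    for pattern, repl in COMPILED:
        text = pattern.sub(repl, text)
    dst.parent.mkdir(parents=True, exist_ok=True)
    dst.write_text(text)


if __name__ == "__main__":
    src_dir = Path("src/confluent_kafka/schema_registry/_async")
    dst_dir = Path("src/confluent_kafka/schema_registry/_sync")
    for src in src_dir.rglob("*.py"):
        unasync_file(src, dst_dir / src.relative_to(src_dir))
```

The real script additionally supports the `--check` mode described in DEVELOPER.md, which compares the generated output against what is committed instead of writing it, so CI can fail the build when the sync code is stale.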
\ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f9df81c65..129373f63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,3 +73,6 @@ optional-dependencies.all = { file = [ "requirements/requirements-avro.txt", "requirements/requirements-json.txt", "requirements/requirements-protobuf.txt"] } + +[tool.pytest.ini_options] +asyncio_mode = "auto" diff --git a/requirements/requirements-tests.txt b/requirements/requirements-tests.txt index 84e6818ca..6ed6d8f34 100644 --- a/requirements/requirements-tests.txt +++ b/requirements/requirements-tests.txt @@ -9,3 +9,5 @@ requests-mock respx pytest_cov pluggy<1.6.0 +pytest-asyncio +async-timeout diff --git a/src/confluent_kafka/schema_registry/__init__.py b/src/confluent_kafka/schema_registry/__init__.py index 2d81f44ac..c19e0f46f 100644 --- a/src/confluent_kafka/schema_registry/__init__.py +++ b/src/confluent_kafka/schema_registry/__init__.py @@ -31,6 +31,7 @@ RuleSet, Schema, SchemaRegistryClient, + AsyncSchemaRegistryClient, SchemaRegistryError, SchemaReference, ServerConfig @@ -57,6 +58,7 @@ "RuleSet", "Schema", "SchemaRegistryClient", + "AsyncSchemaRegistryClient", "SchemaRegistryError", "SchemaReference", "ServerConfig", diff --git a/src/confluent_kafka/schema_registry/_async/__init__.py b/src/confluent_kafka/schema_registry/_async/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/confluent_kafka/schema_registry/_async/avro.py b/src/confluent_kafka/schema_registry/_async/avro.py new file mode 100644 index 000000000..c7a523fbe --- /dev/null +++ b/src/confluent_kafka/schema_registry/_async/avro.py @@ -0,0 +1,603 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from json import loads +from typing import Dict, Union, Optional, Callable + +from fastavro import schemaless_reader, schemaless_writer +from confluent_kafka.schema_registry.common import asyncinit +from confluent_kafka.schema_registry.common.avro import AvroSchema, _schema_loads, \ + get_inline_tags, parse_schema_with_repo, transform, _ContextStringIO, AVRO_TYPE + +from confluent_kafka.schema_registry import (Schema, + topic_subject_name_strategy, + RuleMode, + AsyncSchemaRegistryClient, + prefix_schema_id_serializer, + dual_schema_id_deserializer) +from confluent_kafka.serialization import (SerializationError, + SerializationContext) +from confluent_kafka.schema_registry.rule_registry import RuleRegistry +from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, \ + ParsedSchemaCache, SchemaId + + +__all__ = [ + '_resolve_named_schema', + 'AsyncAvroSerializer', + 'AsyncAvroDeserializer', +] + + +async def _resolve_named_schema( + schema: Schema, schema_registry_client: AsyncSchemaRegistryClient +) -> Dict[str, AvroSchema]: + """ + Resolves named schemas referenced by the provided schema recursively. + :param schema: Schema to resolve named schemas for. 
+ :param schema_registry_client: SchemaRegistryClient to use for retrieval. + :return: named_schemas dict. + """ + named_schemas = {} + if schema.references is not None: + for ref in schema.references: + referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True) + ref_named_schemas = await _resolve_named_schema(referenced_schema.schema, schema_registry_client) + parsed_schema = parse_schema_with_repo( + referenced_schema.schema.schema_str, named_schemas=ref_named_schemas) + named_schemas.update(ref_named_schemas) + named_schemas[ref.name] = parsed_schema + return named_schemas + + +@asyncinit +class AsyncAvroSerializer(AsyncBaseSerializer): + """ + Serializer that outputs Avro binary encoded data with Confluent Schema Registry framing. + + Configuration properties: + + +-----------------------------+----------+--------------------------------------------------+ + | Property Name | Type | Description | + +=============================+==========+==================================================+ + | | | If True, automatically register the configured | + | ``auto.register.schemas`` | bool | schema with Confluent Schema Registry if it has | + | | | not previously been associated with the relevant | + | | | subject (determined via subject.name.strategy). | + | | | | + | | | Defaults to True. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to normalize schemas, which will | + | ``normalize.schemas`` | bool | transform schemas to have a consistent format, | + | | | including ordering properties and references. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the given schema ID for | + | ``use.schema.id`` | int | serialization. | + | | | | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | serialization. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to False. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata``| dict | the given metadata. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to None. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Callable(bytes, SerializationContext, schema_id) | + | | | -> bytes | + | | | | + | ``schema.id.serializer`` | callable | Defines how the schema id/guid is serialized. | + | | | Defaults to prefix_schema_id_serializer. 
| + +-----------------------------+----------+--------------------------------------------------+ + + Schemas are registered against subject names in Confluent Schema Registry that + define a scope in which the schemas can be evolved. By default, the subject name + is formed by concatenating the topic name with the message field (key or value) + separated by a hyphen. + + i.e. {topic name}-{message field} + + Alternative naming strategies may be configured with the property + ``subject.name.strategy``. + + Supported subject name strategies: + + +--------------------------------------+------------------------------+ + | Subject Name Strategy | Output Format | + +======================================+==============================+ + | topic_subject_name_strategy(default) | {topic name}-{message field} | + +--------------------------------------+------------------------------+ + | topic_record_subject_name_strategy | {topic name}-{record name} | + +--------------------------------------+------------------------------+ + | record_subject_name_strategy | {record name} | + +--------------------------------------+------------------------------+ + + See `Subject name strategy `_ for additional details. + + Note: + Prior to serialization, all values must first be converted to + a dict instance. This may handled manually prior to calling + :py:func:`Producer.produce()` or by registering a `to_dict` + callable with AvroSerializer. + + See ``avro_producer.py`` in the examples directory for example usage. + + Note: + Tuple notation can be used to determine which branch of an ambiguous union to take. + + See `fastavro notation `_ + + Args: + schema_registry_client (SchemaRegistryClient): Schema Registry client instance. + + schema_str (str or Schema): + Avro `Schema Declaration. `_ + Accepts either a string or a :py:class:`Schema` instance. Note that string + definitions cannot reference other schemas. For referencing other schemas, + use a :py:class:`Schema` instance. + + to_dict (callable, optional): Callable(object, SerializationContext) -> dict. Converts object to a dict. + + conf (dict): AvroSerializer configuration. 
+ """ # noqa: E501 + __slots__ = ['_known_subjects', '_parsed_schema', '_schema', + '_schema_id', '_schema_name', '_to_dict', '_parsed_schemas'] + + _default_conf = {'auto.register.schemas': True, + 'normalize.schemas': False, + 'use.schema.id': None, + 'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy, + 'schema.id.serializer': prefix_schema_id_serializer} + + async def __init__( + self, + schema_registry_client: AsyncSchemaRegistryClient, + schema_str: Union[str, Schema, None] = None, + to_dict: Optional[Callable[[object, SerializationContext], dict]] = None, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None + ): + super().__init__() + if isinstance(schema_str, str): + schema = _schema_loads(schema_str) + elif isinstance(schema_str, Schema): + schema = schema_str + else: + schema = None + + self._registry = schema_registry_client + self._schema_id = None + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._known_subjects = set() + self._parsed_schemas = ParsedSchemaCache() + + if to_dict is not None and not callable(to_dict): + raise ValueError("to_dict must be callable with the signature " + "to_dict(object, SerializationContext)->dict") + + self._to_dict = to_dict + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._auto_register = conf_copy.pop('auto.register.schemas') + if not isinstance(self._auto_register, bool): + raise ValueError("auto.register.schemas must be a boolean value") + + self._normalize_schemas = conf_copy.pop('normalize.schemas') + if not isinstance(self._normalize_schemas, bool): + raise ValueError("normalize.schemas must be a boolean value") + + self._use_schema_id = conf_copy.pop('use.schema.id') + if (self._use_schema_id is not None and + not isinstance(self._use_schema_id, int)): + raise ValueError("use.schema.id must be an int value") + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + if self._use_latest_version and self._auto_register: + raise ValueError("cannot enable both use.latest.version and auto.register.schemas") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + if not callable(self._schema_id_serializer): + raise ValueError("schema.id.serializer must be callable") + + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + if schema: + parsed_schema = await self._get_parsed_schema(schema) + + if isinstance(parsed_schema, list): + # if parsed_schema is a list, we have an Avro union and there + # is no valid schema name. 
This is fine because the only use of + # schema_name is for supplying the subject name to the registry + # and union types should use topic_subject_name_strategy, which + # just discards the schema name anyway + schema_name = None + else: + # The Avro spec states primitives have a name equal to their type + # i.e. {"type": "string"} has a name of string. + # This function does not comply. + # https://github.com/fastavro/fastavro/issues/415 + schema_dict = loads(schema.schema_str) + schema_name = parsed_schema.get("name", schema_dict.get("type")) + else: + schema_name = None + parsed_schema = None + + self._schema = schema + self._schema_name = schema_name + self._parsed_schema = parsed_schema + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + return self.__serialize(obj, ctx) + + async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + """ + Serializes an object to Avro binary format, prepending it with Confluent + Schema Registry framing. + + Args: + obj (object): The object instance to serialize. + + ctx (SerializationContext): Metadata pertaining to the serialization operation. + + Raises: + SerializerError: If any error occurs serializing obj. + SchemaRegistryError: If there was an error registering the schema with + Schema Registry, or auto.register.schemas is + false and the schema was not registered. + + Returns: + bytes: Confluent Schema Registry encoded Avro bytes + """ + + if obj is None: + return None + + subject = self._subject_name_func(ctx, self._schema_name) + latest_schema = await self._get_reader_schema(subject) + if latest_schema is not None: + self._schema_id = SchemaId(AVRO_TYPE, latest_schema.schema_id, latest_schema.guid) + elif subject not in self._known_subjects: + # Check to ensure this schema has been registered under subject_name. + if self._auto_register: + # The schema name will always be the same. We can't however register + # a schema without a subject so we set the schema_id here to handle + # the initial registration. 
+ registered_schema = await self._registry.register_schema_full_response( + subject, self._schema, self._normalize_schemas) + self._schema_id = SchemaId(AVRO_TYPE, registered_schema.schema_id, registered_schema.guid) + else: + registered_schema = await self._registry.lookup_schema( + subject, self._schema, self._normalize_schemas) + self._schema_id = SchemaId(AVRO_TYPE, registered_schema.schema_id, registered_schema.guid) + + self._known_subjects.add(subject) + + if self._to_dict is not None: + value = self._to_dict(obj, ctx) + else: + value = obj + + if latest_schema is not None: + parsed_schema = await self._get_parsed_schema(latest_schema.schema) + def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 + transform(rule_ctx, parsed_schema, msg, field_transform)) + value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, value, get_inline_tags(parsed_schema), + field_transformer) + else: + parsed_schema = self._parsed_schema + + with _ContextStringIO() as fo: + # write the record to the rest of the buffer + schemaless_writer(fo, parsed_schema, value) + + return self._schema_id_serializer(fo.getvalue(), ctx, self._schema_id) + + async def _get_parsed_schema(self, schema: Schema) -> AvroSchema: + parsed_schema = self._parsed_schemas.get_parsed_schema(schema) + if parsed_schema is not None: + return parsed_schema + + named_schemas = await _resolve_named_schema(schema, self._registry) + prepared_schema = _schema_loads(schema.schema_str) + parsed_schema = parse_schema_with_repo( + prepared_schema.schema_str, named_schemas=named_schemas) + + self._parsed_schemas.set(schema, parsed_schema) + return parsed_schema + + +@asyncinit +class AsyncAvroDeserializer(AsyncBaseDeserializer): + """ + Deserializer for Avro binary encoded data with Confluent Schema Registry + framing. + + +-----------------------------+----------+--------------------------------------------------+ + | Property Name | Type | Description | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | deserialization. | + | | | | + | | | Defaults to False. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata``| dict | the given metadata. | + | | | | + | | | Defaults to None. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Callable(bytes, SerializationContext, schema_id) | + | | | -> io.BytesIO | + | | | | + | ``schema.id.deserializer`` | callable | Defines how the schema id/guid is deserialized. | + | | | Defaults to dual_schema_id_deserializer. | + +-----------------------------+----------+--------------------------------------------------+ + Note: + By default, Avro complex types are returned as dicts. 
This behavior can + be overridden by registering a callable ``from_dict`` with the deserializer to + convert the dicts to the desired type. + + See ``avro_consumer.py`` in the examples directory in the examples + directory for example usage. + + Args: + schema_registry_client (SchemaRegistryClient): Confluent Schema Registry + client instance. + + schema_str (str, Schema, optional): Avro reader schema declaration Accepts + either a string or a :py:class:`Schema` instance. If not provided, the + writer schema will be used as the reader schema. Note that string + definitions cannot reference other schemas. For referencing other schemas, + use a :py:class:`Schema` instance. + + from_dict (callable, optional): Callable(dict, SerializationContext) -> object. + Converts a dict to an instance of some object. + + return_record_name (bool): If True, when reading a union of records, the result will + be a tuple where the first value is the name of the record and the second value is + the record itself. Defaults to False. + + See Also: + `Apache Avro Schema Declaration `_ + + `Apache Avro Schema Resolution `_ + """ + + __slots__ = ['_reader_schema', '_from_dict', '_return_record_name', + '_schema', '_parsed_schemas'] + + _default_conf = {'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy, + 'schema.id.deserializer': dual_schema_id_deserializer} + + async def __init__( + self, + schema_registry_client: AsyncSchemaRegistryClient, + schema_str: Union[str, Schema, None] = None, + from_dict: Optional[Callable[[dict, SerializationContext], object]] = None, + return_record_name: bool = False, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None + ): + super().__init__() + schema = None + if schema_str is not None: + if isinstance(schema_str, str): + schema = _schema_loads(schema_str) + elif isinstance(schema_str, Schema): + schema = schema_str + else: + raise TypeError('You must pass either schema string or schema object') + + self._schema = schema + self._registry = schema_registry_client + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._parsed_schemas = ParsedSchemaCache() + self._use_schema_id = None + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') + if not callable(self._schema_id_deserializer): + raise ValueError("schema.id.deserializer must be callable") + + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + if schema: + self._reader_schema = await self._get_parsed_schema(self._schema) + else: + self._reader_schema = None + + if from_dict is not None and not callable(from_dict): + raise ValueError("from_dict must be callable with the 
signature " + "from_dict(SerializationContext, dict) -> object") + self._from_dict = from_dict + + self._return_record_name = return_record_name + if not isinstance(self._return_record_name, bool): + raise ValueError("return_record_name must be a boolean value") + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + return self.__deserialize(data, ctx) + + async def __deserialize( + self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + """ + Deserialize Avro binary encoded data with Confluent Schema Registry framing to + a dict, or object instance according to from_dict, if specified. + + Arguments: + data (bytes): bytes + + ctx (SerializationContext): Metadata relevant to the serialization + operation. + + Raises: + SerializerError: if an error occurs parsing data. + + Returns: + object: If data is None, then None. Else, a dict, or object instance according + to from_dict, if specified. + """ # noqa: E501 + + if data is None: + return None + + if len(data) <= 5: + raise SerializationError("Expecting data framing of length 6 bytes or " + "more but total data size is {} bytes. This " + "message was not produced with a Confluent " + "Schema Registry serializer".format(len(data))) + + subject = self._subject_name_func(ctx, None) if ctx else None + latest_schema = None + if subject is not None: + latest_schema = await self._get_reader_schema(subject) + + schema_id = SchemaId(AVRO_TYPE) + payload = self._schema_id_deserializer(data, ctx, schema_id) + + writer_schema_raw = await self._get_writer_schema(schema_id, subject) + writer_schema = await self._get_parsed_schema(writer_schema_raw) + + if subject is None: + subject = self._subject_name_func(ctx, writer_schema.get("name")) if ctx else None + if subject is not None: + latest_schema = await self._get_reader_schema(subject) + + if latest_schema is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) + reader_schema_raw = latest_schema.schema + reader_schema = await self._get_parsed_schema(latest_schema.schema) + elif self._schema is not None: + migrations = None + reader_schema_raw = self._schema + reader_schema = self._reader_schema + else: + migrations = None + reader_schema_raw = writer_schema_raw + reader_schema = writer_schema + + if migrations: + obj_dict = schemaless_reader(payload, + writer_schema, + None, + self._return_record_name) + obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) + else: + obj_dict = schemaless_reader(payload, + writer_schema, + reader_schema, + self._return_record_name) + + def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 + transform(rule_ctx, reader_schema, message, field_transform)) + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, obj_dict, get_inline_tags(reader_schema), + field_transformer) + + if self._from_dict is not None: + return self._from_dict(obj_dict, ctx) + + return obj_dict + + async def _get_parsed_schema(self, schema: Schema) -> AvroSchema: + parsed_schema = self._parsed_schemas.get_parsed_schema(schema) + if parsed_schema is not None: + return parsed_schema + + named_schemas = await _resolve_named_schema(schema, self._registry) + prepared_schema = _schema_loads(schema.schema_str) + parsed_schema = 
parse_schema_with_repo( + prepared_schema.schema_str, named_schemas=named_schemas) + + self._parsed_schemas.set(schema, parsed_schema) + return parsed_schema diff --git a/src/confluent_kafka/schema_registry/_async/json_schema.py b/src/confluent_kafka/schema_registry/_async/json_schema.py new file mode 100644 index 000000000..c57c6b9a0 --- /dev/null +++ b/src/confluent_kafka/schema_registry/_async/json_schema.py @@ -0,0 +1,659 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Union, Optional, Tuple, Callable + +from cachetools import LRUCache +from jsonschema import ValidationError +from jsonschema.protocols import Validator +from jsonschema.validators import validator_for +from referencing import Registry, Resource + +from confluent_kafka.schema_registry import (Schema, + topic_subject_name_strategy, + RuleMode, AsyncSchemaRegistryClient, + prefix_schema_id_serializer, + dual_schema_id_deserializer) +from confluent_kafka.schema_registry.common import asyncinit +from confluent_kafka.schema_registry.common.json_schema import ( + DEFAULT_SPEC, JsonSchema, _retrieve_via_httpx, transform, _ContextStringIO, JSON_TYPE +) +from confluent_kafka.schema_registry.rule_registry import RuleRegistry +from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, \ + ParsedSchemaCache, SchemaId +from confluent_kafka.serialization import (SerializationError, + SerializationContext) + +__all__ = [ + '_resolve_named_schema', + 'AsyncJSONSerializer', + 'AsyncJSONDeserializer' +] + + +async def _resolve_named_schema( + schema: Schema, schema_registry_client: AsyncSchemaRegistryClient, + ref_registry: Optional[Registry] = None +) -> Registry: + """ + Resolves named schemas referenced by the provided schema recursively. + :param schema: Schema to resolve named schemas for. + :param schema_registry_client: SchemaRegistryClient to use for retrieval. + :param ref_registry: Registry of named schemas resolved recursively. + :return: Registry + """ + if ref_registry is None: + # Retrieve external schemas for backward compatibility + ref_registry = Registry(retrieve=_retrieve_via_httpx) + if schema.references is not None: + for ref in schema.references: + referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True) + ref_registry = await _resolve_named_schema(referenced_schema.schema, schema_registry_client, ref_registry) + referenced_schema_dict = json.loads(referenced_schema.schema.schema_str) + resource = Resource.from_contents( + referenced_schema_dict, default_specification=DEFAULT_SPEC) + ref_registry = ref_registry.with_resource(ref.name, resource) + return ref_registry + + +@asyncinit +class AsyncJSONSerializer(AsyncBaseSerializer): + """ + Serializer that outputs JSON encoded data with Confluent Schema Registry framing. 
+ + Configuration properties: + + +-----------------------------+----------+----------------------------------------------------+ + | Property Name | Type | Description | + +=============================+==========+====================================================+ + | | | If True, automatically register the configured | + | ``auto.register.schemas`` | bool | schema with Confluent Schema Registry if it has | + | | | not previously been associated with the relevant | + | | | subject (determined via subject.name.strategy). | + | | | | + | | | Defaults to True. | + | | | | + | | | Raises SchemaRegistryError if the schema was not | + | | | registered against the subject, or could not be | + | | | successfully registered. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to normalize schemas, which will | + | ``normalize.schemas`` | bool | transform schemas to have a consistent format, | + | | | including ordering properties and references. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to use the given schema ID for | + | ``use.schema.id`` | int | serialization. | + | | | | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | serialization. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to False. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata``| dict | the given metadata. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to None. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to validate the payload against the | + | ``validate`` | bool | the given schema. | + | | | | + +-----------------------------+----------+----------------------------------------------------+ + | | | Callable(bytes, SerializationContext, schema_id) | + | | | -> bytes | + | | | | + | ``schema.id.serializer`` | callable | Defines how the schema id/guid is serialized. | + | | | Defaults to prefix_schema_id_serializer. | + +-----------------------------+----------+----------------------------------------------------+ + + Schemas are registered against subject names in Confluent Schema Registry that + define a scope in which the schemas can be evolved. By default, the subject name + is formed by concatenating the topic name with the message field (key or value) + separated by a hyphen. + + i.e. {topic name}-{message field} + + Alternative naming strategies may be configured with the property + ``subject.name.strategy``. 
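A subject name strategy is just a callable taking the ``SerializationContext`` and the record or schema name. As a hedged illustration (the environment prefix and function name here are hypothetical, not part of the library), a custom strategy can be supplied via ``conf`` in place of the built-in strategies listed below.

```python
from confluent_kafka.serialization import SerializationContext


def env_prefixed_topic_strategy(ctx: SerializationContext, record_name: str) -> str:
    # Hypothetical strategy: prefix the default {topic}-{field} subject with an
    # environment name so staging and production subjects never collide.
    return f"staging.{ctx.topic}-{ctx.field}"


# Passed to the serializer through its configuration dict, e.g.:
# conf={'subject.name.strategy': env_prefixed_topic_strategy}
```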
+ + Supported subject name strategies: + + +--------------------------------------+------------------------------+ + | Subject Name Strategy | Output Format | + +======================================+==============================+ + | topic_subject_name_strategy(default) | {topic name}-{message field} | + +--------------------------------------+------------------------------+ + | topic_record_subject_name_strategy | {topic name}-{record name} | + +--------------------------------------+------------------------------+ + | record_subject_name_strategy | {record name} | + +--------------------------------------+------------------------------+ + + See `Subject name strategy `_ for additional details. + + Notes: + The ``title`` annotation, referred to elsewhere as a record name + is not strictly required by the JSON Schema specification. It is + however required by this serializer in order to register the schema + with Confluent Schema Registry. + + Prior to serialization, all objects must first be converted to + a dict instance. This may be handled manually prior to calling + :py:func:`Producer.produce()` or by registering a `to_dict` + callable with JSONSerializer. + + Args: + schema_str (str, Schema): + `JSON Schema definition. `_ + Accepts schema as either a string or a :py:class:`Schema` instance. + Note that string definitions cannot reference other schemas. For + referencing other schemas, use a :py:class:`Schema` instance. + + schema_registry_client (SchemaRegistryClient): Schema Registry + client instance. + + to_dict (callable, optional): Callable(object, SerializationContext) -> dict. + Converts object to a dict. + + conf (dict): JsonSerializer configuration. + """ # noqa: E501 + __slots__ = ['_known_subjects', '_parsed_schema', '_ref_registry', + '_schema', '_schema_id', '_schema_name', '_to_dict', + '_parsed_schemas', '_validators', '_validate', '_json_encode'] + + _default_conf = {'auto.register.schemas': True, + 'normalize.schemas': False, + 'use.schema.id': None, + 'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy, + 'schema.id.serializer': prefix_schema_id_serializer, + 'validate': True} + + async def __init__( + self, + schema_str: Union[str, Schema, None], + schema_registry_client: AsyncSchemaRegistryClient, + to_dict: Optional[Callable[[object, SerializationContext], dict]] = None, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None, + json_encode: Optional[Callable] = None, + ): + super().__init__() + if isinstance(schema_str, str): + self._schema = Schema(schema_str, schema_type="JSON") + elif isinstance(schema_str, Schema): + self._schema = schema_str + else: + self._schema = None + + self._json_encode = json_encode or json.dumps + self._registry = schema_registry_client + self._rule_registry = ( + rule_registry if rule_registry else RuleRegistry.get_global_instance() + ) + self._schema_id = None + self._known_subjects = set() + self._parsed_schemas = ParsedSchemaCache() + self._validators = LRUCache(1000) + + if to_dict is not None and not callable(to_dict): + raise ValueError("to_dict must be callable with the signature " + "to_dict(object, SerializationContext)->dict") + + self._to_dict = to_dict + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._auto_register = conf_copy.pop('auto.register.schemas') + if not isinstance(self._auto_register, bool): + raise ValueError("auto.register.schemas 
must be a boolean value") + + self._normalize_schemas = conf_copy.pop('normalize.schemas') + if not isinstance(self._normalize_schemas, bool): + raise ValueError("normalize.schemas must be a boolean value") + + self._use_schema_id = conf_copy.pop('use.schema.id') + if (self._use_schema_id is not None and + not isinstance(self._use_schema_id, int)): + raise ValueError("use.schema.id must be an int value") + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + if self._use_latest_version and self._auto_register: + raise ValueError("cannot enable both use.latest.version and auto.register.schemas") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + if not callable(self._schema_id_serializer): + raise ValueError("schema.id.serializer must be callable") + + self._validate = conf_copy.pop('validate') + if not isinstance(self._normalize_schemas, bool): + raise ValueError("validate must be a boolean value") + + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + schema_dict, ref_registry = await self._get_parsed_schema(self._schema) + if schema_dict: + schema_name = schema_dict.get('title', None) + else: + schema_name = None + + self._schema_name = schema_name + self._parsed_schema = schema_dict + self._ref_registry = ref_registry + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + return self.__serialize(obj, ctx) + + async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + """ + Serializes an object to JSON, prepending it with Confluent Schema Registry + framing. + + Args: + obj (object): The object instance to serialize. + + ctx (SerializationContext): Metadata relevant to the serialization + operation. + + Raises: + SerializerError if any error occurs serializing obj. + + Returns: + bytes: None if obj is None, else a byte array containing the JSON + serialized data with Confluent Schema Registry framing. + """ + + if obj is None: + return None + + subject = self._subject_name_func(ctx, self._schema_name) + latest_schema = await self._get_reader_schema(subject) + if latest_schema is not None: + self._schema_id = SchemaId(JSON_TYPE, latest_schema.schema_id, latest_schema.guid) + elif subject not in self._known_subjects: + # Check to ensure this schema has been registered under subject_name. + if self._auto_register: + # The schema name will always be the same. We can't however register + # a schema without a subject so we set the schema_id here to handle + # the initial registration. 
+ registered_schema = await self._registry.register_schema_full_response( + subject, self._schema, self._normalize_schemas) + self._schema_id = SchemaId(JSON_TYPE, registered_schema.schema_id, registered_schema.guid) + else: + registered_schema = await self._registry.lookup_schema( + subject, self._schema, self._normalize_schemas) + self._schema_id = SchemaId(JSON_TYPE, registered_schema.schema_id, registered_schema.guid) + + self._known_subjects.add(subject) + + if self._to_dict is not None: + value = self._to_dict(obj, ctx) + else: + value = obj + + if latest_schema is not None: + schema = latest_schema.schema + parsed_schema, ref_registry = await self._get_parsed_schema(latest_schema.schema) + root_resource = Resource.from_contents( + parsed_schema, default_specification=DEFAULT_SPEC) + ref_resolver = ref_registry.resolver_with_root(root_resource) + def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 + transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) + value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, value, None, + field_transformer) + else: + schema = self._schema + parsed_schema, ref_registry = self._parsed_schema, self._ref_registry + + if self._validate: + try: + validator = self._get_validator(schema, parsed_schema, ref_registry) + validator.validate(value) + except ValidationError as ve: + raise SerializationError(ve.message) + + with _ContextStringIO() as fo: + # JSON dump always writes a str never bytes + # https://docs.python.org/3/library/json.html + encoded_value = self._json_encode(value) + if isinstance(encoded_value, str): + encoded_value = encoded_value.encode("utf8") + fo.write(encoded_value) + + return self._schema_id_serializer(fo.getvalue(), ctx, self._schema_id) + + async def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]: + if schema is None: + return None, None + + result = self._parsed_schemas.get_parsed_schema(schema) + if result is not None: + return result + + ref_registry = await _resolve_named_schema(schema, self._registry) + parsed_schema = json.loads(schema.schema_str) + + self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) + return parsed_schema, ref_registry + + def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Registry) -> Validator: + validator = self._validators.get(schema, None) + if validator is not None: + return validator + + cls = validator_for(parsed_schema) + cls.check_schema(parsed_schema) + validator = cls(parsed_schema, registry=registry) + + self._validators[schema] = validator + return validator + + +@asyncinit +class AsyncJSONDeserializer(AsyncBaseDeserializer): + """ + Deserializer for JSON encoded data with Confluent Schema Registry + framing. + + Configuration properties: + + +-----------------------------+----------+----------------------------------------------------+ + | Property Name | Type | Description | + +=============================+==========+====================================================+ + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | deserialization. | + | | | | + | | | Defaults to False. 
| + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata``| dict | the given metadata. | + | | | | + | | | Defaults to None. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to validate the payload against the | + | ``validate`` | bool | the given schema. | + | | | | + +-----------------------------+----------+----------------------------------------------------+ + | | | Callable(bytes, SerializationContext, schema_id) | + | | | -> io.BytesIO | + | | | | + | ``schema.id.deserializer`` | callable | Defines how the schema id/guid is deserialized. | + | | | Defaults to dual_schema_id_deserializer. | + +-----------------------------+----------+----------------------------------------------------+ + + Args: + schema_str (str, Schema, optional): + `JSON schema definition `_ + Accepts schema as either a string or a :py:class:`Schema` instance. + Note that string definitions cannot reference other schemas. For referencing other schemas, + use a :py:class:`Schema` instance. If not provided, schemas will be + retrieved from schema_registry_client based on the schema ID in the + wire header of each message. + + from_dict (callable, optional): Callable(dict, SerializationContext) -> object. + Converts a dict to a Python object instance. + + schema_registry_client (SchemaRegistryClient, optional): Schema Registry client instance. Needed if ``schema_str`` is a schema referencing other schemas or is not provided. 
+ """ # noqa: E501 + + __slots__ = ['_reader_schema', '_ref_registry', '_from_dict', '_schema', + '_parsed_schemas', '_validators', '_validate', '_json_decode'] + + _default_conf = {'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy, + 'schema.id.deserializer': dual_schema_id_deserializer, + 'validate': True} + + async def __init__( + self, + schema_str: Union[str, Schema, None], + from_dict: Optional[Callable[[dict, SerializationContext], object]] = None, + schema_registry_client: Optional[AsyncSchemaRegistryClient] = None, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None, + json_decode: Optional[Callable] = None, + ): + super().__init__() + if isinstance(schema_str, str): + schema = Schema(schema_str, schema_type="JSON") + elif isinstance(schema_str, Schema): + schema = schema_str + if bool(schema.references) and schema_registry_client is None: + raise ValueError( + """schema_registry_client must be provided if "schema_str" is a Schema instance with references""") + elif schema_str is None: + if schema_registry_client is None: + raise ValueError( + """schema_registry_client must be provided if "schema_str" is not provided""" + ) + schema = schema_str + else: + raise TypeError('You must pass either str or Schema') + + self._schema = schema + self._registry = schema_registry_client + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._parsed_schemas = ParsedSchemaCache() + self._validators = LRUCache(1000) + self._json_decode = json_decode or json.loads + self._use_schema_id = None + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') + if not callable(self._subject_name_func): + raise ValueError("schema.id.deserializer must be callable") + + self._validate = conf_copy.pop('validate') + if not isinstance(self._validate, bool): + raise ValueError("validate must be a boolean value") + + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + if schema: + self._reader_schema, self._ref_registry = await self._get_parsed_schema(self._schema) + else: + self._reader_schema, self._ref_registry = None, None + + if from_dict is not None and not callable(from_dict): + raise ValueError("from_dict must be callable with the signature" + " from_dict(dict, SerializationContext) -> object") + + self._from_dict = from_dict + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + return self.__serialize(data, ctx) + + async def __serialize(self, data: bytes, 
ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + """ + Deserialize a JSON encoded record with Confluent Schema Registry framing to + a dict, or object instance according to from_dict if from_dict is specified. + + Args: + data (bytes): A JSON serialized record with Confluent Schema Registry framing. + + ctx (SerializationContext): Metadata relevant to the serialization operation. + + Returns: + A dict, or object instance according to from_dict if from_dict is specified. + + Raises: + SerializerError: If there was an error reading the Confluent framing data, or + if ``data`` was not successfully validated with the configured schema. + """ + + if data is None: + return None + + subject = self._subject_name_func(ctx, None) + latest_schema = None + if subject is not None and self._registry is not None: + latest_schema = await self._get_reader_schema(subject) + + schema_id = SchemaId(JSON_TYPE) + payload = self._schema_id_deserializer(data, ctx, schema_id) + + # JSON documents are self-describing; no need to query schema + obj_dict = self._json_decode(payload.read()) + + if self._registry is not None: + writer_schema_raw = await self._get_writer_schema(schema_id, subject) + writer_schema, writer_ref_registry = await self._get_parsed_schema(writer_schema_raw) + if subject is None: + subject = self._subject_name_func(ctx, writer_schema.get("title")) + if subject is not None: + latest_schema = await self._get_reader_schema(subject) + else: + writer_schema_raw = None + writer_schema, writer_ref_registry = None, None + + if latest_schema is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) + reader_schema_raw = latest_schema.schema + reader_schema, reader_ref_registry = await self._get_parsed_schema(latest_schema.schema) + elif self._schema is not None: + migrations = None + reader_schema_raw = self._schema + reader_schema, reader_ref_registry = self._reader_schema, self._ref_registry + else: + migrations = None + reader_schema_raw = writer_schema_raw + reader_schema, reader_ref_registry = writer_schema, writer_ref_registry + + if migrations: + obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) + + reader_root_resource = Resource.from_contents( + reader_schema, default_specification=DEFAULT_SPEC) + reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) + + def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 + transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, + "$", message, field_transform)) + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, obj_dict, None, + field_transformer) + + if self._validate: + try: + validator = self._get_validator(reader_schema_raw, reader_schema, reader_ref_registry) + validator.validate(obj_dict) + except ValidationError as ve: + raise SerializationError(ve.message) + + if self._from_dict is not None: + return self._from_dict(obj_dict, ctx) + + return obj_dict + + async def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]: + if schema is None: + return None, None + + result = self._parsed_schemas.get_parsed_schema(schema) + if result is not None: + return result + + ref_registry = await _resolve_named_schema(schema, self._registry) + parsed_schema = json.loads(schema.schema_str) + + self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) + return parsed_schema, ref_registry + + def _get_validator(self, schema: Schema, 
parsed_schema: JsonSchema, registry: Registry) -> Validator: + validator = self._validators.get(schema, None) + if validator is not None: + return validator + + cls = validator_for(parsed_schema) + cls.check_schema(parsed_schema) + validator = cls(parsed_schema, registry=registry) + + self._validators[schema] = validator + return validator diff --git a/src/confluent_kafka/schema_registry/_async/protobuf.py b/src/confluent_kafka/schema_registry/_async/protobuf.py new file mode 100644 index 000000000..ff060883c --- /dev/null +++ b/src/confluent_kafka/schema_registry/_async/protobuf.py @@ -0,0 +1,717 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020-2022 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import warnings +from typing import Set, List, Union, Optional, Tuple + +from google.protobuf import json_format, descriptor_pb2 +from google.protobuf.descriptor_pool import DescriptorPool +from google.protobuf.descriptor import Descriptor, FileDescriptor +from google.protobuf.message import DecodeError, Message +from google.protobuf.message_factory import GetMessageClass + +from confluent_kafka.schema_registry import (reference_subject_name_strategy, + topic_subject_name_strategy, + prefix_schema_id_serializer, dual_schema_id_deserializer) +from confluent_kafka.schema_registry.schema_registry_client import AsyncSchemaRegistryClient +from confluent_kafka.schema_registry.common.protobuf import _bytes, _create_index_array, \ + _init_pool, _is_builtin, _schema_to_str, _str_to_proto, transform, _ContextStringIO, PROTOBUF_TYPE +from confluent_kafka.schema_registry.rule_registry import RuleRegistry +from confluent_kafka.schema_registry import (Schema, + SchemaReference, + RuleMode) +from confluent_kafka.serialization import SerializationError, \ + SerializationContext +from confluent_kafka.schema_registry.common import asyncinit +from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, \ + ParsedSchemaCache, SchemaId + +__all__ = [ + '_resolve_named_schema', + 'AsyncProtobufSerializer', + 'AsyncProtobufDeserializer', +] + + +async def _resolve_named_schema( + schema: Schema, + schema_registry_client: AsyncSchemaRegistryClient, + pool: DescriptorPool, + visited: Optional[Set[str]] = None +): + """ + Resolves named schemas referenced by the provided schema recursively. + :param schema: Schema to resolve named schemas for. + :param schema_registry_client: AsyncSchemaRegistryClient to use for retrieval. + :param pool: DescriptorPool to add resolved schemas to. 
+ :return: DescriptorPool + """ + if visited is None: + visited = set() + if schema.references is not None: + for ref in schema.references: + if _is_builtin(ref.name) or ref.name in visited: + continue + visited.add(ref.name) + referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True, 'serialized') + await _resolve_named_schema(referenced_schema.schema, schema_registry_client, pool, visited) + file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) + pool.Add(file_descriptor_proto) + + +@asyncinit +class AsyncProtobufSerializer(AsyncBaseSerializer): + """ + Serializer for Protobuf Message derived classes. Serialization format is Protobuf, + with Confluent Schema Registry framing. + + Configuration properties: + + +-------------------------------------+----------+------------------------------------------------------+ + | Property Name | Type | Description | + +=====================================+==========+======================================================+ + | | | If True, automatically register the configured | + | ``auto.register.schemas`` | bool | schema with Confluent Schema Registry if it has | + | | | not previously been associated with the relevant | + | | | subject (determined via subject.name.strategy). | + | | | | + | | | Defaults to True. | + | | | | + | | | Raises SchemaRegistryError if the schema was not | + | | | registered against the subject, or could not be | + | | | successfully registered. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to normalize schemas, which will | + | ``normalize.schemas`` | bool | transform schemas to have a consistent format, | + | | | including ordering properties and references. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to use the given schema ID for | + | ``use.schema.id`` | int | serialization. | + | | | | + +-----------------------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | serialization. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to False. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata`` | dict | the given metadata. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to None. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether or not to skip known types when resolving | + | ``skip.known.types`` | bool | schema dependencies. | + | | | | + | | | Defaults to True. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. 
| + +-------------------------------------+----------+------------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``reference.subject.name.strategy`` | callable | Defines how Schema Registry subject names for schema | + | | | references are constructed. | + | | | | + | | | Defaults to reference_subject_name_strategy | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Callable(bytes, SerializationContext, schema_id) | + | | | -> bytes | + | | | | + | ``schema.id.serializer`` | callable | Defines how the schema id/guid is serialized. | + | | | Defaults to prefix_schema_id_serializer. | + +-------------------------------------+----------+------------------------------------------------------+ + | ``use.deprecated.format`` | bool | Specifies whether the Protobuf serializer should | + | | | serialize message indexes without zig-zag encoding. | + | | | This option must be explicitly configured as older | + | | | and newer Protobuf producers are incompatible. | + | | | If the consumers of the topic being produced to are | + | | | using confluent-kafka-python <1.8 then this property | + | | | must be set to True until all old consumers have | + | | | have been upgraded. | + | | | | + | | | Warning: This configuration property will be removed | + | | | in a future version of the client. | + +-------------------------------------+----------+------------------------------------------------------+ + + Schemas are registered against subject names in Confluent Schema Registry that + define a scope in which the schemas can be evolved. By default, the subject name + is formed by concatenating the topic name with the message field (key or value) + separated by a hyphen. + + i.e. {topic name}-{message field} + + Alternative naming strategies may be configured with the property + ``subject.name.strategy``. + + Supported subject name strategies + + +--------------------------------------+------------------------------+ + | Subject Name Strategy | Output Format | + +======================================+==============================+ + | topic_subject_name_strategy(default) | {topic name}-{message field} | + +--------------------------------------+------------------------------+ + | topic_record_subject_name_strategy | {topic name}-{record name} | + +--------------------------------------+------------------------------+ + | record_subject_name_strategy | {record name} | + +--------------------------------------+------------------------------+ + + See `Subject name strategy `_ for additional details. + + Args: + msg_type (Message): Protobuf Message type. + + schema_registry_client (SchemaRegistryClient): Schema Registry + client instance. + + conf (dict): ProtobufSerializer configuration. 
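+
+    Example (illustrative sketch only; ``user_pb2`` is a hypothetical generated
+    module and the registry URL is an assumption)::
+
+        from confluent_kafka.schema_registry import AsyncSchemaRegistryClient
+        from confluent_kafka.serialization import SerializationContext, MessageField
+
+        async def serialize_user(user):
+            client = AsyncSchemaRegistryClient({'url': 'http://localhost:8081'})
+            serializer = await AsyncProtobufSerializer(
+                user_pb2.User, client, {'use.deprecated.format': False})
+            return await serializer(
+                user, SerializationContext('users', MessageField.VALUE))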
+ + See Also: + `Protobuf API reference `_ + """ # noqa: E501 + __slots__ = ['_skip_known_types', '_known_subjects', '_msg_class', '_index_array', + '_schema', '_schema_id', '_ref_reference_subject_func', + '_use_deprecated_format', '_parsed_schemas'] + + _default_conf = { + 'auto.register.schemas': True, + 'normalize.schemas': False, + 'use.schema.id': None, + 'use.latest.version': False, + 'use.latest.with.metadata': None, + 'skip.known.types': True, + 'subject.name.strategy': topic_subject_name_strategy, + 'reference.subject.name.strategy': reference_subject_name_strategy, + 'schema.id.serializer': prefix_schema_id_serializer, + 'use.deprecated.format': False, + } + + async def __init__( + self, + msg_type: Message, + schema_registry_client: AsyncSchemaRegistryClient, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None + ): + super().__init__() + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._auto_register = conf_copy.pop('auto.register.schemas') + if not isinstance(self._auto_register, bool): + raise ValueError("auto.register.schemas must be a boolean value") + + self._normalize_schemas = conf_copy.pop('normalize.schemas') + if not isinstance(self._normalize_schemas, bool): + raise ValueError("normalize.schemas must be a boolean value") + + self._use_schema_id = conf_copy.pop('use.schema.id') + if (self._use_schema_id is not None and + not isinstance(self._use_schema_id, int)): + raise ValueError("use.schema.id must be an int value") + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + if self._use_latest_version and self._auto_register: + raise ValueError("cannot enable both use.latest.version and auto.register.schemas") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._skip_known_types = conf_copy.pop('skip.known.types') + if not isinstance(self._skip_known_types, bool): + raise ValueError("skip.known.types must be a boolean value") + + self._use_deprecated_format = conf_copy.pop('use.deprecated.format') + if not isinstance(self._use_deprecated_format, bool): + raise ValueError("use.deprecated.format must be a boolean value") + if self._use_deprecated_format: + warnings.warn("ProtobufSerializer: the 'use.deprecated.format' " + "configuration property, and the ability to use the " + "old incorrect Protobuf serializer heading format " + "introduced in confluent-kafka-python v1.4.0, " + "will be removed in an upcoming release in 2021 Q2. 
" + "Please migrate your Python Protobuf producers and " + "consumers to 'use.deprecated.format':False as " + "soon as possible") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + self._ref_reference_subject_func = conf_copy.pop( + 'reference.subject.name.strategy') + if not callable(self._ref_reference_subject_func): + raise ValueError("subject.name.strategy must be callable") + + self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + if not callable(self._schema_id_serializer): + raise ValueError("schema.id.serializer must be callable") + + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + self._registry = schema_registry_client + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._schema_id = None + self._known_subjects = set() + self._msg_class = msg_type + self._parsed_schemas = ParsedSchemaCache() + + descriptor = msg_type.DESCRIPTOR + self._index_array = _create_index_array(descriptor) + self._schema = Schema(_schema_to_str(descriptor.file), + schema_type='PROTOBUF') + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + @staticmethod + def _write_varint(buf: io.BytesIO, val: int, zigzag: bool = True): + """ + Writes val to buf, either using zigzag or uvarint encoding. + + Args: + buf (BytesIO): buffer to write to. + val (int): integer to be encoded. + zigzag (bool): whether to encode in zigzag or uvarint encoding + """ + + if zigzag: + val = (val << 1) ^ (val >> 63) + + while (val & ~0x7f) != 0: + buf.write(_bytes((val & 0x7f) | 0x80)) + val >>= 7 + buf.write(_bytes(val)) + + @staticmethod + def _encode_varints(buf: io.BytesIO, ints: List[int], zigzag: bool = True): + """ + Encodes each int as a uvarint onto buf + + Args: + buf (BytesIO): buffer to write to. + ints ([int]): ints to be encoded. + zigzag (bool): whether to encode in zigzag or uvarint encoding + """ + + assert len(ints) > 0 + # The root element at the 0 position does not need a length prefix. + if ints == [0]: + buf.write(_bytes(0x00)) + return + + AsyncProtobufSerializer._write_varint(buf, len(ints), zigzag=zigzag) + + for value in ints: + AsyncProtobufSerializer._write_varint(buf, value, zigzag=zigzag) + + async def _resolve_dependencies( + self, ctx: SerializationContext, + file_desc: FileDescriptor + ) -> List[SchemaReference]: + """ + Resolves and optionally registers schema references recursively. + + Args: + ctx (SerializationContext): Serialization context. + + file_desc (FileDescriptor): file descriptor to traverse. 
+ """ + + schema_refs = [] + for dep in file_desc.dependencies: + if self._skip_known_types and _is_builtin(dep.name): + continue + dep_refs = await self._resolve_dependencies(ctx, dep) + subject = self._ref_reference_subject_func(ctx, dep) + schema = Schema(_schema_to_str(dep), + references=dep_refs, + schema_type='PROTOBUF') + if self._auto_register: + await self._registry.register_schema(subject, schema) + + reference = await self._registry.lookup_schema(subject, schema) + # schema_refs are per file descriptor + schema_refs.append(SchemaReference(dep.name, + subject, + reference.version)) + return schema_refs + + def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + return self.__serialize(message, ctx) + + async def __serialize(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + """ + Serializes an instance of a class derived from Protobuf Message, and prepends + it with Confluent Schema Registry framing. + + Args: + message (Message): An instance of a class derived from Protobuf Message. + + ctx (SerializationContext): Metadata relevant to the serialization. + operation. + + Raises: + SerializerError if any error occurs during serialization. + + Returns: + None if messages is None, else a byte array containing the Protobuf + serialized message with Confluent Schema Registry framing. + """ + + if message is None: + return None + + if not isinstance(message, self._msg_class): + raise ValueError("message must be of type {} not {}" + .format(self._msg_class, type(message))) + + subject = self._subject_name_func(ctx, message.DESCRIPTOR.full_name) if ctx else None + latest_schema = None + if subject is not None: + latest_schema = await self._get_reader_schema(subject, fmt='serialized') + + if latest_schema is not None: + self._schema_id = SchemaId(PROTOBUF_TYPE, latest_schema.schema_id, latest_schema.guid) + + elif subject not in self._known_subjects and ctx is not None: + references = await self._resolve_dependencies(ctx, message.DESCRIPTOR.file) + self._schema = Schema( + self._schema.schema_str, + self._schema.schema_type, + references + ) + + if self._auto_register: + registered_schema = await self._registry.register_schema_full_response( + subject, self._schema, self._normalize_schemas) + self._schema_id = SchemaId(PROTOBUF_TYPE, registered_schema.schema_id, registered_schema.guid) + else: + registered_schema = await self._registry.lookup_schema( + subject, self._schema, self._normalize_schemas) + self._schema_id = SchemaId(PROTOBUF_TYPE, registered_schema.schema_id, registered_schema.guid) + + self._known_subjects.add(subject) + + if latest_schema is not None: + fd_proto, pool = await self._get_parsed_schema(latest_schema.schema) + fd = pool.FindFileByName(fd_proto.name) + desc = fd.message_types_by_name[message.DESCRIPTOR.name] + def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 + transform(rule_ctx, desc, msg, field_transform)) + message = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, message, None, + field_transformer) + + with _ContextStringIO() as fo: + fo.write(message.SerializeToString()) + self._schema_id.message_indexes = self._index_array + return self._schema_id_serializer(fo.getvalue(), ctx, self._schema_id) + + async def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]: + result = self._parsed_schemas.get_parsed_schema(schema) + if result is not None: + return result + + 
pool = DescriptorPool() + _init_pool(pool) + await _resolve_named_schema(schema, self._registry, pool) + fd_proto = _str_to_proto("default", schema.schema_str) + pool.Add(fd_proto) + self._parsed_schemas.set(schema, (fd_proto, pool)) + return fd_proto, pool + + +@asyncinit +class AsyncProtobufDeserializer(AsyncBaseDeserializer): + """ + Deserializer for Protobuf serialized data with Confluent Schema Registry framing. + + Args: + message_type (Message derived type): Protobuf Message type. + conf (dict): Configuration dictionary. + + ProtobufDeserializer configuration properties: + + +-------------------------------------+----------+------------------------------------------------------+ + | Property Name | Type | Description | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | deserialization. | + | | | | + | | | Defaults to False. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata`` | dict | the given metadata. | + | | | | + | | | Defaults to None. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka. schema_registry | + | | | namespace . | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Callable(bytes, SerializationContext, schema_id) | + | | | -> io.BytesIO | + | | | | + | ``schema.id.deserializer`` | callable | Defines how the schema id/guid is deserialized. | + | | | Defaults to dual_schema_id_deserializer. | + +-------------------------------------+----------+------------------------------------------------------+ + | ``use.deprecated.format`` | bool | Specifies whether the Protobuf deserializer should | + | | | deserialize message indexes without zig-zag encoding.| + | | | This option must be explicitly configured as older | + | | | and newer Protobuf producers are incompatible. | + | | | If Protobuf messages in the topic to consume were | + | | | produced with confluent-kafka-python <1.8 then this | + | | | property must be set to True until all old messages | + | | | have been processed and producers have been upgraded.| + | | | Warning: This configuration property will be removed | + | | | in a future version of the client. 
| + +-------------------------------------+----------+------------------------------------------------------+ + + + See Also: + `Protobuf API reference `_ + """ + + __slots__ = ['_msg_class', '_use_deprecated_format', '_parsed_schemas'] + + _default_conf = { + 'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy, + 'schema.id.deserializer': dual_schema_id_deserializer, + 'use.deprecated.format': False, + } + + async def __init__( + self, + message_type: Message, + conf: Optional[dict] = None, + schema_registry_client: Optional[AsyncSchemaRegistryClient] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None + ): + super().__init__() + + self._registry = schema_registry_client + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._parsed_schemas = ParsedSchemaCache() + self._use_schema_id = None + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') + if not callable(self._schema_id_deserializer): + raise ValueError("schema.id.deserializer must be callable") + + self._use_deprecated_format = conf_copy.pop('use.deprecated.format') + if not isinstance(self._use_deprecated_format, bool): + raise ValueError("use.deprecated.format must be a boolean value") + if self._use_deprecated_format: + warnings.warn("ProtobufDeserializer: the 'use.deprecated.format' " + "configuration property, and the ability to use the " + "old incorrect Protobuf serializer heading format " + "introduced in confluent-kafka-python v1.4.0, " + "will be removed in an upcoming release in 2022 Q2. " + "Please migrate your Python Protobuf producers and " + "consumers to 'use.deprecated.format':False as " + "soon as possible") + + descriptor = message_type.DESCRIPTOR + self._msg_class = GetMessageClass(descriptor) + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + return self.__serialize(data, ctx) + + async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + """ + Deserialize a serialized protobuf message with Confluent Schema Registry + framing. + + Args: + data (bytes): Serialized protobuf message with Confluent Schema + Registry framing. + + ctx (SerializationContext): Metadata relevant to the serialization + operation. + + Returns: + Message: Protobuf Message instance. + + Raises: + SerializerError: If there was an error reading the Confluent framing + data, or parsing the protobuf serialized message. 
+        """
+
+        if data is None:
+            return None
+
+        subject = self._subject_name_func(ctx, None)
+        latest_schema = None
+        if subject is not None and self._registry is not None:
+            latest_schema = await self._get_reader_schema(subject, fmt='serialized')
+
+        schema_id = SchemaId(PROTOBUF_TYPE)
+        payload = self._schema_id_deserializer(data, ctx, schema_id)
+        msg_index = schema_id.message_indexes
+
+        if self._registry is not None:
+            writer_schema_raw = await self._get_writer_schema(schema_id, subject, fmt='serialized')
+            fd_proto, pool = await self._get_parsed_schema(writer_schema_raw)
+            writer_schema = pool.FindFileByName(fd_proto.name)
+            writer_desc = self._get_message_desc(pool, writer_schema, msg_index)
+            if subject is None:
+                subject = self._subject_name_func(ctx, writer_desc.full_name)
+                if subject is not None:
+                    latest_schema = await self._get_reader_schema(subject, fmt='serialized')
+        else:
+            writer_schema_raw = None
+            writer_schema = None
+
+        if latest_schema is not None:
+            migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None)
+            reader_schema_raw = latest_schema.schema
+            fd_proto, pool = await self._get_parsed_schema(latest_schema.schema)
+            reader_schema = pool.FindFileByName(fd_proto.name)
+        else:
+            migrations = None
+            reader_schema_raw = writer_schema_raw
+            reader_schema = writer_schema
+
+        if reader_schema is not None:
+            # Initialize reader desc to first message in file
+            reader_desc = self._get_message_desc(pool, reader_schema, [0])
+            # Attempt to find a reader desc with the same name as the writer
+            reader_desc = reader_schema.message_types_by_name.get(writer_desc.name, reader_desc)
+
+        if migrations:
+            msg = GetMessageClass(writer_desc)()
+            try:
+                msg.ParseFromString(payload.read())
+            except DecodeError as e:
+                raise SerializationError(str(e))
+
+            obj_dict = json_format.MessageToDict(msg, True)
+            obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict)
+            msg = GetMessageClass(reader_desc)()
+            msg = json_format.ParseDict(obj_dict, msg)
+        else:
+            # Protobuf Messages are self-describing; no need to query schema
+            msg = self._msg_class()
+            try:
+                msg.ParseFromString(payload.read())
+            except DecodeError as e:
+                raise SerializationError(str(e))
+
+        def field_transformer(rule_ctx, field_transform, message): return (  # noqa: E731
+            transform(rule_ctx, reader_desc, message, field_transform))
+        msg = self._execute_rules(ctx, subject, RuleMode.READ, None,
+                                  reader_schema_raw, msg, None,
+                                  field_transformer)
+        return msg
+
+    async def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]:
+        result = self._parsed_schemas.get_parsed_schema(schema)
+        if result is not None:
+            return result
+
+        pool = DescriptorPool()
+        _init_pool(pool)
+        await _resolve_named_schema(schema, self._registry, pool)
+        fd_proto = _str_to_proto("default", schema.schema_str)
+        pool.Add(fd_proto)
+        self._parsed_schemas.set(schema, (fd_proto, pool))
+        return fd_proto, pool
+
+    def _get_message_desc(
+            self, pool: DescriptorPool, fd: FileDescriptor,
+            msg_index: List[int]
+    ) -> Descriptor:
+        file_desc_proto = descriptor_pb2.FileDescriptorProto()
+        fd.CopyToProto(file_desc_proto)
+        (full_name, desc_proto) = self._get_message_desc_proto("", file_desc_proto, msg_index)
+        package = file_desc_proto.package
+        qualified_name = package + "." 
+ full_name if package else full_name + return pool.FindMessageTypeByName(qualified_name) + + def _get_message_desc_proto( + self, + path: str, + desc: Union[descriptor_pb2.FileDescriptorProto, descriptor_pb2.DescriptorProto], + msg_index: List[int] + ) -> Tuple[str, descriptor_pb2.DescriptorProto]: + index = msg_index[0] + if isinstance(desc, descriptor_pb2.FileDescriptorProto): + msg = desc.message_type[index] + path = path + "." + msg.name if path else msg.name + if len(msg_index) == 1: + return path, msg + return self._get_message_desc_proto(path, msg, msg_index[1:]) + else: + msg = desc.nested_type[index] + path = path + "." + msg.name if path else msg.name + if len(msg_index) == 1: + return path, msg + return self._get_message_desc_proto(path, msg, msg_index[1:]) diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py new file mode 100644 index 000000000..f7e22d039 --- /dev/null +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -0,0 +1,1165 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import asyncio +import json +import logging +import time +import urllib +from urllib.parse import unquote, urlparse + +import httpx +from typing import List, Dict, Optional, Union, Any, Tuple, Callable + +from cachetools import TTLCache, LRUCache +from httpx import Response + +from authlib.integrations.httpx_client import AsyncOAuth2Client + +from confluent_kafka.schema_registry.error import SchemaRegistryError, OAuthTokenError +from confluent_kafka.schema_registry.common.schema_registry_client import ( + RegisteredSchema, + ServerConfig, + is_success, + is_retriable, + _BearerFieldProvider, + full_jitter, + _SchemaCache, + Schema, + _StaticFieldProvider, +) + +__all__ = [ + '_urlencode', + '_AsyncCustomOAuthClient', + '_AsyncOAuthClient', + '_AsyncBaseRestClient', + '_AsyncRestClient', + 'AsyncSchemaRegistryClient', +] + +# TODO: consider adding `six` dependency or employing a compat file +# Python 2.7 is officially EOL so compatibility issue will be come more the norm. +# We need a better way to handle these issues. +# Six is one possibility but the compat file pattern used by requests +# is also quite nice. 
+# +# six: https://pypi.org/project/six/ +# compat file : https://github.com/psf/requests/blob/master/requests/compat.py +try: + string_type = basestring # noqa + + def _urlencode(value: str) -> str: + return urllib.quote(value, safe='') +except NameError: + string_type = str + + def _urlencode(value: str) -> str: + return urllib.parse.quote(value, safe='') + +log = logging.getLogger(__name__) + + +class _AsyncCustomOAuthClient(_BearerFieldProvider): + def __init__(self, custom_function: Callable[[Dict], Dict], custom_config: dict): + self.custom_function = custom_function + self.custom_config = custom_config + + async def get_bearer_fields(self) -> dict: + return await self.custom_function(self.custom_config) + + +class _AsyncOAuthClient(_BearerFieldProvider): + def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str, logical_cluster: str, + identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int): + self.token = None + self.logical_cluster = logical_cluster + self.identity_pool = identity_pool + self.client = AsyncOAuth2Client(client_id=client_id, client_secret=client_secret, scope=scope) + self.token_endpoint = token_endpoint + self.max_retries = max_retries + self.retries_wait_ms = retries_wait_ms + self.retries_max_wait_ms = retries_max_wait_ms + self.token_expiry_threshold = 0.8 + + async def get_bearer_fields(self) -> dict: + return { + 'bearer.auth.token': await self.get_access_token(), + 'bearer.auth.logical.cluster': self.logical_cluster, + 'bearer.auth.identity.pool.id': self.identity_pool + } + + def token_expired(self) -> bool: + expiry_window = self.token['expires_in'] * self.token_expiry_threshold + + return self.token['expires_at'] < time.time() + expiry_window + + async def get_access_token(self) -> str: + if not self.token or self.token_expired(): + await self.generate_access_token() + + return self.token['access_token'] + + async def generate_access_token(self) -> None: + for i in range(self.max_retries + 1): + try: + self.token = await self.client.fetch_token(url=self.token_endpoint, grant_type='client_credentials') + return + except Exception as e: + if i >= self.max_retries: + raise OAuthTokenError(f"Failed to retrieve token after {self.max_retries} " + f"attempts due to error: {str(e)}") + await asyncio.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000) + + +class _AsyncBaseRestClient(object): + + def __init__(self, conf: dict): + # copy dict to avoid mutating the original + conf_copy = conf.copy() + + base_url = conf_copy.pop('url', None) + if base_url is None: + raise ValueError("Missing required configuration property url") + if not isinstance(base_url, string_type): + raise TypeError("url must be a str, not " + str(type(base_url))) + base_urls = [] + for url in base_url.split(','): + url = url.strip().rstrip('/') + if not url.startswith('http') and not url.startswith('mock'): + raise ValueError("Invalid url {}".format(url)) + base_urls.append(url) + if not base_urls: + raise ValueError("Missing required configuration property url") + self.base_urls = base_urls + + self.verify = True + ca = conf_copy.pop('ssl.ca.location', None) + if ca is not None: + self.verify = ca + + key: Optional[str] = conf_copy.pop('ssl.key.location', None) + client_cert: Optional[str] = conf_copy.pop('ssl.certificate.location', None) + self.cert: Union[str, Tuple[str, str], None] = None + + if client_cert is not None and key is not None: + self.cert = (client_cert, key) + + if client_cert is not None and 
key is None: + self.cert = client_cert + + if key is not None and client_cert is None: + raise ValueError("ssl.certificate.location required when" + " configuring ssl.key.location") + + parsed = urlparse(self.base_urls[0]) + try: + userinfo = (unquote(parsed.username), unquote(parsed.password)) + except (AttributeError, TypeError): + userinfo = ("", "") + if 'basic.auth.user.info' in conf_copy: + if userinfo != ('', ''): + raise ValueError("basic.auth.user.info configured with" + " userinfo credentials in the URL." + " Remove userinfo credentials from the url or" + " remove basic.auth.user.info from the" + " configuration") + + userinfo = tuple(conf_copy.pop('basic.auth.user.info', '').split(':', 1)) + + if len(userinfo) != 2: + raise ValueError("basic.auth.user.info must be in the form" + " of {username}:{password}") + + self.auth = userinfo if userinfo != ('', '') else None + + # The following adds support for proxy config + # If specified: it uses the specified proxy details when making requests + self.proxy = None + proxy = conf_copy.pop('proxy', None) + if proxy is not None: + self.proxy = proxy + + self.timeout = None + timeout = conf_copy.pop('timeout', None) + if timeout is not None: + self.timeout = timeout + + self.cache_capacity = 1000 + cache_capacity = conf_copy.pop('cache.capacity', None) + if cache_capacity is not None: + if not isinstance(cache_capacity, (int, float)): + raise TypeError("cache.capacity must be a number, not " + str(type(cache_capacity))) + self.cache_capacity = cache_capacity + + self.cache_latest_ttl_sec = None + cache_latest_ttl_sec = conf_copy.pop('cache.latest.ttl.sec', None) + if cache_latest_ttl_sec is not None: + if not isinstance(cache_latest_ttl_sec, (int, float)): + raise TypeError("cache.latest.ttl.sec must be a number, not " + str(type(cache_latest_ttl_sec))) + self.cache_latest_ttl_sec = cache_latest_ttl_sec + + self.max_retries = 3 + max_retries = conf_copy.pop('max.retries', None) + if max_retries is not None: + if not isinstance(max_retries, (int, float)): + raise TypeError("max.retries must be a number, not " + str(type(max_retries))) + self.max_retries = max_retries + + self.retries_wait_ms = 1000 + retries_wait_ms = conf_copy.pop('retries.wait.ms', None) + if retries_wait_ms is not None: + if not isinstance(retries_wait_ms, (int, float)): + raise TypeError("retries.wait.ms must be a number, not " + + str(type(retries_wait_ms))) + self.retries_wait_ms = retries_wait_ms + + self.retries_max_wait_ms = 20000 + retries_max_wait_ms = conf_copy.pop('retries.max.wait.ms', None) + if retries_max_wait_ms is not None: + if not isinstance(retries_max_wait_ms, (int, float)): + raise TypeError("retries.max.wait.ms must be a number, not " + + str(type(retries_max_wait_ms))) + self.retries_max_wait_ms = retries_max_wait_ms + + self.bearer_field_provider = None + logical_cluster = None + identity_pool = None + self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None) + if self.bearer_auth_credentials_source is not None: + self.auth = None + + if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}: + headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id'] + missing_headers = [header for header in headers if header not in conf_copy] + if missing_headers: + raise ValueError("Missing required bearer configuration properties: {}" + .format(", ".join(missing_headers))) + + logical_cluster = conf_copy.pop('bearer.auth.logical.cluster') + if not isinstance(logical_cluster, str): + raise 
TypeError("logical cluster must be a str, not " + str(type(logical_cluster))) + + identity_pool = conf_copy.pop('bearer.auth.identity.pool.id') + if not isinstance(identity_pool, str): + raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) + + if self.bearer_auth_credentials_source == 'OAUTHBEARER': + properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope', + 'bearer.auth.issuer.endpoint.url'] + missing_properties = [prop for prop in properties_list if prop not in conf_copy] + if missing_properties: + raise ValueError("Missing required OAuth configuration properties: {}". + format(", ".join(missing_properties))) + + self.client_id = conf_copy.pop('bearer.auth.client.id') + if not isinstance(self.client_id, string_type): + raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id))) + + self.client_secret = conf_copy.pop('bearer.auth.client.secret') + if not isinstance(self.client_secret, string_type): + raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret))) + + self.scope = conf_copy.pop('bearer.auth.scope') + if not isinstance(self.scope, string_type): + raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope))) + + self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url') + if not isinstance(self.token_endpoint, string_type): + raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " + + str(type(self.token_endpoint))) + + self.bearer_field_provider = _AsyncOAuthClient( + self.client_id, self.client_secret, self.scope, + self.token_endpoint, logical_cluster, identity_pool, + self.max_retries, self.retries_wait_ms, + self.retries_max_wait_ms) + elif self.bearer_auth_credentials_source == 'STATIC_TOKEN': + if 'bearer.auth.token' not in conf_copy: + raise ValueError("Missing bearer.auth.token") + static_token = conf_copy.pop('bearer.auth.token') + self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) + if not isinstance(static_token, string_type): + raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) + elif self.bearer_auth_credentials_source == 'CUSTOM': + custom_bearer_properties = ['bearer.auth.custom.provider.function', + 'bearer.auth.custom.provider.config'] + missing_custom_properties = [prop for prop in custom_bearer_properties if prop not in conf_copy] + if missing_custom_properties: + raise ValueError("Missing required custom OAuth configuration properties: {}". 
+ format(", ".join(missing_custom_properties))) + + custom_function = conf_copy.pop('bearer.auth.custom.provider.function') + if not callable(custom_function): + raise TypeError("bearer.auth.custom.provider.function must be a callable, not " + + str(type(custom_function))) + + custom_config = conf_copy.pop('bearer.auth.custom.provider.config') + if not isinstance(custom_config, dict): + raise TypeError("bearer.auth.custom.provider.config must be a dict, not " + + str(type(custom_config))) + + self.bearer_field_provider = _AsyncCustomOAuthClient(custom_function, custom_config) + else: + raise ValueError('Unrecognized bearer.auth.credentials.source') + + # Any leftover keys are unknown to _RestClient + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + async def get(self, url: str, query: Optional[dict] = None) -> Any: + raise NotImplementedError() + + async def post(self, url: str, body: Optional[dict], **kwargs) -> Any: + raise NotImplementedError() + + async def delete(self, url: str) -> Any: + raise NotImplementedError() + + async def put(self, url: str, body: Optional[dict] = None) -> Any: + raise NotImplementedError() + + +class _AsyncRestClient(_AsyncBaseRestClient): + """ + HTTP client for Confluent Schema Registry. + + See SchemaRegistryClient for configuration details. + + Args: + conf (dict): Dictionary containing _RestClient configuration + """ + + def __init__(self, conf: dict): + super().__init__(conf) + + self.session = httpx.AsyncClient( + verify=self.verify, + cert=self.cert, + auth=self.auth, + proxy=self.proxy, + timeout=self.timeout + ) + + async def handle_bearer_auth(self, headers: dict) -> None: + bearer_fields = await self.bearer_field_provider.get_bearer_fields() + required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster'] + + missing_fields = [] + for field in required_fields: + if field not in bearer_fields: + missing_fields.append(field) + + if missing_fields: + raise ValueError("Missing required bearer auth fields, needs to be set in config or custom function: {}" + .format(", ".join(missing_fields))) + + headers["Authorization"] = "Bearer {}".format(bearer_fields['bearer.auth.token']) + headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] + headers['target-sr-cluster'] = bearer_fields['bearer.auth.logical.cluster'] + + async def get(self, url: str, query: Optional[dict] = None) -> Any: + return await self.send_request(url, method='GET', query=query) + + async def post(self, url: str, body: Optional[dict], **kwargs) -> Any: + return await self.send_request(url, method='POST', body=body) + + async def delete(self, url: str) -> Any: + return await self.send_request(url, method='DELETE') + + async def put(self, url: str, body: Optional[dict] = None) -> Any: + return await self.send_request(url, method='PUT', body=body) + + async def send_request( + self, url: str, method: str, body: Optional[dict] = None, + query: Optional[dict] = None + ) -> Any: + """ + Sends HTTP request to the SchemaRegistry, trying each base URL in turn. + + All unsuccessful attempts will raise a SchemaRegistryError with the + response contents. In most cases this will be accompanied by a + Schema Registry supplied error code. + + In the event the response is malformed an error_code of -1 will be used. 
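+
+        Each base URL is tried in turn; within a URL, requests that fail with a
+        retriable status code are retried up to ``max.retries`` times with
+        jittered backoff (``retries.wait.ms`` / ``retries.max.wait.ms``).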
+ + Args: + url (str): Request path + + method (str): HTTP method + + body (str): Request content + + query (dict): Query params to attach to the URL + + Returns: + dict: Schema Registry response content. + """ + + headers = {'Accept': "application/vnd.schemaregistry.v1+json," + " application/vnd.schemaregistry+json," + " application/json"} + + if body is not None: + body = json.dumps(body) + headers = {'Content-Length': str(len(body)), + 'Content-Type': "application/vnd.schemaregistry.v1+json"} + + if self.bearer_auth_credentials_source: + await self.handle_bearer_auth(headers) + + response = None + for i, base_url in enumerate(self.base_urls): + try: + response = await self.send_http_request( + base_url, url, method, headers, body, query) + + if is_success(response.status_code): + return response.json() + + if not is_retriable(response.status_code) or i == len(self.base_urls) - 1: + break + except Exception as e: + if i == len(self.base_urls) - 1: + # Raise the exception since we have no more urls to try + raise e + + try: + raise SchemaRegistryError(response.status_code, + response.json().get('error_code'), + response.json().get('message')) + # Schema Registry may return malformed output when it hits unexpected errors + except (ValueError, KeyError, AttributeError): + raise SchemaRegistryError(response.status_code, + -1, + "Unknown Schema Registry Error: " + + str(response.content)) + + async def send_http_request( + self, base_url: str, url: str, method: str, headers: Optional[dict], + body: Optional[str] = None, query: Optional[dict] = None + ) -> Response: + """ + Sends HTTP request to the SchemaRegistry. + + All unsuccessful attempts will raise a SchemaRegistryError with the + response contents. In most cases this will be accompanied by a + Schema Registry supplied error code. + + In the event the response is malformed an error_code of -1 will be used. + + Args: + base_url (str): Schema Registry base URL + + url (str): Request path + + method (str): HTTP method + + headers (dict): Headers + + body (str): Request content + + query (dict): Query params to attach to the URL + + Returns: + Response: Schema Registry response content. + """ + response = None + for i in range(self.max_retries + 1): + response = await self.session.request( + method, url="/".join([base_url, url]), + headers=headers, content=body, params=query) + + if is_success(response.status_code): + return response + + if not is_retriable(response.status_code) or i >= self.max_retries: + return response + + await asyncio.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000) + return response + + +class AsyncSchemaRegistryClient(object): + """ + A Confluent Schema Registry client. + + Configuration properties (* indicates a required field): + + +------------------------------+------+-------------------------------------------------+ + | Property name | type | Description | + +==============================+======+=================================================+ + | ``url`` * | str | Comma-separated list of Schema Registry URLs. | + +------------------------------+------+-------------------------------------------------+ + | | | Path to CA certificate file used | + | ``ssl.ca.location`` | str | to verify the Schema Registry's | + | | | private key. | + +------------------------------+------+-------------------------------------------------+ + | | | Path to client's private key | + | | | (PEM) used for authentication. 
| + | ``ssl.key.location`` | str | | + | | | ``ssl.certificate.location`` must also be set. | + +------------------------------+------+-------------------------------------------------+ + | | | Path to client's public key (PEM) used for | + | | | authentication. | + | ``ssl.certificate.location`` | str | | + | | | May be set without ssl.key.location if the | + | | | private key is stored within the PEM as well. | + +------------------------------+------+-------------------------------------------------+ + | | | Client HTTP credentials in the form of | + | | | ``username:password``. | + | ``basic.auth.user.info`` | str | | + | | | By default userinfo is extracted from | + | | | the URL if present. | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``proxy`` | str | Proxy such as http://localhost:8030. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``timeout`` | int | Request timeout. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``cache.capacity`` | int | Cache capacity. Defaults to 1000. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``cache.latest.ttl.sec`` | int | TTL in seconds for caching the latest schema. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``max.retries`` | int | Maximum retries for a request. Defaults to 2. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | Maximum time to wait for the first retry. | + | | | When jitter is applied, the actual wait may | + | ``retries.wait.ms`` | int | be less. | + | | | | + | | | Defaults to 1000. | + +------------------------------+------+-------------------------------------------------+ + + Args: + conf (dict): Schema Registry client configuration. + + See Also: + `Confluent Schema Registry documentation `_ + """ # noqa: E501 + + def __init__(self, conf: dict): + self._conf = conf + self._rest_client = _AsyncRestClient(conf) + self._cache = _SchemaCache() + cache_capacity = self._rest_client.cache_capacity + cache_ttl = self._rest_client.cache_latest_ttl_sec + if cache_ttl is not None: + self._latest_version_cache = TTLCache(cache_capacity, cache_ttl) + self._latest_with_metadata_cache = TTLCache(cache_capacity, cache_ttl) + else: + self._latest_version_cache = LRUCache(cache_capacity) + self._latest_with_metadata_cache = LRUCache(cache_capacity) + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + if self._rest_client is not None: + await self._rest_client.session.aclose() + + def config(self): + return self._conf + + async def register_schema( + self, subject_name: str, schema: 'Schema', + normalize_schemas: bool = False + ) -> int: + """ + Registers a schema under ``subject_name``. + + Args: + subject_name (str): subject to register a schema under + schema (Schema): Schema instance to register + normalize_schemas (bool): Normalize schema before registering + + Returns: + int: Schema id + + Raises: + SchemaRegistryError: if Schema violates this subject's + Compatibility policy or is otherwise invalid. 
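+
+        Example (illustrative sketch; the registry URL, subject name and schema
+        body are assumptions)::
+
+            async with AsyncSchemaRegistryClient({'url': 'http://localhost:8081'}) as client:
+                schema = Schema('{"type": "string"}', schema_type='AVRO')
+                schema_id = await client.register_schema('example-value', schema)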
+ + See Also: + `POST Subject API Reference `_ + """ # noqa: E501 + + registered_schema = await self.register_schema_full_response(subject_name, schema, normalize_schemas) + return registered_schema.schema_id + + async def register_schema_full_response( + self, subject_name: str, schema: 'Schema', + normalize_schemas: bool = False + ) -> 'RegisteredSchema': + """ + Registers a schema under ``subject_name``. + + Args: + subject_name (str): subject to register a schema under + schema (Schema): Schema instance to register + normalize_schemas (bool): Normalize schema before registering + + Returns: + int: Schema id + + Raises: + SchemaRegistryError: if Schema violates this subject's + Compatibility policy or is otherwise invalid. + + See Also: + `POST Subject API Reference `_ + """ # noqa: E501 + + schema_id = self._cache.get_id_by_schema(subject_name, schema) + if schema_id is not None: + result = self._cache.get_schema_by_id(subject_name, schema_id) + if result is not None: + return RegisteredSchema(schema_id, result[0], result[1], subject_name, None) + + request = schema.to_dict() + + response = await self._rest_client.post( + 'subjects/{}/versions?normalize={}'.format(_urlencode(subject_name), normalize_schemas), + body=request) + + registered_schema = RegisteredSchema.from_dict(response) + + # The registered schema may not be fully populated + s = registered_schema.schema if registered_schema.schema.schema_str is not None else schema + self._cache.set_schema(subject_name, registered_schema.schema_id, + registered_schema.guid, s) + + return registered_schema + + async def get_schema( + self, schema_id: int, subject_name: Optional[str] = None, fmt: Optional[str] = None + ) -> 'Schema': + """ + Fetches the schema associated with ``schema_id`` from the + Schema Registry. The result is cached so subsequent attempts will not + require an additional round-trip to the Schema Registry. + + Args: + schema_id (int): Schema id + subject_name (str): Subject name the schema is registered under + fmt (str): Format of the schema + + Returns: + Schema: Schema instance identified by the ``schema_id`` + + Raises: + SchemaRegistryError: If schema can't be found. + + See Also: + `GET Schema API Reference `_ + """ # noqa: E501 + + result = self._cache.get_schema_by_id(subject_name, schema_id) + if result is not None: + return result[1] + + query = {'subject': subject_name} if subject_name is not None else None + if fmt is not None: + if query is not None: + query['format'] = fmt + else: + query = {'format': fmt} + response = await self._rest_client.get('schemas/ids/{}'.format(schema_id), query) + + registered_schema = RegisteredSchema.from_dict(response) + + self._cache.set_schema(subject_name, schema_id, + registered_schema.guid, registered_schema.schema) + + return registered_schema.schema + + async def get_schema_by_guid( + self, guid: str, fmt: Optional[str] = None + ) -> 'Schema': + """ + Fetches the schema associated with ``guid`` from the + Schema Registry. The result is cached so subsequent attempts will not + require an additional round-trip to the Schema Registry. + + Args: + guid (str): Schema guid + fmt (str): Format of the schema + + Returns: + Schema: Schema instance identified by the ``guid`` + + Raises: + SchemaRegistryError: If schema can't be found. 
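+
+        Example (illustrative sketch; ``guid`` is assumed to be a value
+        previously returned by the registry, e.g. ``RegisteredSchema.guid``)::
+
+            schema = await client.get_schema_by_guid(guid)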
+ + See Also: + `GET Schema API Reference `_ + """ # noqa: E501 + + schema = self._cache.get_schema_by_guid(guid) + if schema is not None: + return schema + + if fmt is not None: + query = {'format': fmt} + response = await self._rest_client.get('schemas/guids/{}'.format(guid), query) + + registered_schema = RegisteredSchema.from_dict(response) + + self._cache.set_schema(None, registered_schema.schema_id, + registered_schema.guid, registered_schema.schema) + + return registered_schema.schema + + async def lookup_schema( + self, subject_name: str, schema: 'Schema', + normalize_schemas: bool = False, deleted: bool = False + ) -> 'RegisteredSchema': + """ + Returns ``schema`` registration information for ``subject``. + + Args: + subject_name (str): Subject name the schema is registered under + schema (Schema): Schema instance. + normalize_schemas (bool): Normalize schema before registering + deleted (bool): Whether to include deleted schemas. + + Returns: + RegisteredSchema: Subject registration information for this schema. + + Raises: + SchemaRegistryError: If schema or subject can't be found + + See Also: + `POST Subject API Reference `_ + """ # noqa: E501 + + registered_schema = self._cache.get_registered_by_subject_schema(subject_name, schema) + if registered_schema is not None: + return registered_schema + + request = schema.to_dict() + + response = await self._rest_client.post( + 'subjects/{}?normalize={}&deleted={}'.format( + _urlencode(subject_name), normalize_schemas, deleted), + body=request + ) + + result = RegisteredSchema.from_dict(response) + + # Ensure the schema matches the input + registered_schema = RegisteredSchema( + schema_id=result.schema_id, + guid=result.guid, + subject=result.subject, + version=result.version, + schema=schema, + ) + + self._cache.set_registered_schema(schema, registered_schema) + + return registered_schema + + async def get_subjects(self) -> List[str]: + """ + List all subjects registered with the Schema Registry + + Returns: + list(str): Registered subject names + + Raises: + SchemaRegistryError: if subjects can't be found + + See Also: + `GET subjects API Reference `_ + """ # noqa: E501 + + return await self._rest_client.get('subjects') + + async def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int]: + """ + Deletes the specified subject and its associated compatibility level if + registered. It is recommended to use this API only when a topic needs + to be recycled or in development environments. + + Args: + subject_name (str): subject name + permanent (bool): True for a hard delete, False (default) for a soft delete + + Returns: + list(int): Versions deleted under this subject + + Raises: + SchemaRegistryError: if the request was unsuccessful. + + See Also: + `DELETE Subject API Reference `_ + """ # noqa: E501 + + if permanent: + versions = await self._rest_client.delete( + 'subjects/{}?permanent=true'.format(_urlencode(subject_name)) + ) + self._cache.remove_by_subject(subject_name) + else: + versions = await self._rest_client.delete( + 'subjects/{}'.format(_urlencode(subject_name)) + ) + + return versions + + async def get_latest_version( + self, subject_name: str, fmt: Optional[str] = None + ) -> 'RegisteredSchema': + """ + Retrieves latest registered version for subject + + Args: + subject_name (str): Subject name. + fmt (str): Format of the schema + + Returns: + RegisteredSchema: Registration information for this version. + + Raises: + SchemaRegistryError: if the version can't be found or is invalid. 
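+
+        Example (illustrative sketch; the subject name is an assumption)::
+
+            latest = await client.get_latest_version('example-value')
+            print(latest.schema_id, latest.version)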
+ + See Also: + `GET Subject Version API Reference `_ + """ # noqa: E501 + + registered_schema = self._latest_version_cache.get(subject_name, None) + if registered_schema is not None: + return registered_schema + + query = {'format': fmt} if fmt is not None else None + response = await self._rest_client.get( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), 'latest'), query + ) + + registered_schema = RegisteredSchema.from_dict(response) + + self._latest_version_cache[subject_name] = registered_schema + + return registered_schema + + async def get_latest_with_metadata( + self, subject_name: str, metadata: Dict[str, str], + deleted: bool = False, fmt: Optional[str] = None + ) -> 'RegisteredSchema': + """ + Retrieves latest registered version for subject with the given metadata + + Args: + subject_name (str): Subject name. + metadata (dict): The key-value pairs for the metadata. + deleted (bool): Whether to include deleted schemas. + fmt (str): Format of the schema + + Returns: + RegisteredSchema: Registration information for this version. + + Raises: + SchemaRegistryError: if the version can't be found or is invalid. + """ # noqa: E501 + + cache_key = (subject_name, frozenset(metadata.items()), deleted) + registered_schema = self._latest_with_metadata_cache.get(cache_key, None) + if registered_schema is not None: + return registered_schema + + query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} + keys = metadata.keys() + if keys: + query['key'] = [_urlencode(key) for key in keys] + query['value'] = [_urlencode(metadata[key]) for key in keys] + + response = await self._rest_client.get( + 'subjects/{}/metadata'.format(_urlencode(subject_name)), query + ) + + registered_schema = RegisteredSchema.from_dict(response) + + self._latest_with_metadata_cache[cache_key] = registered_schema + + return registered_schema + + async def get_version( + self, subject_name: str, version: int, + deleted: bool = False, fmt: Optional[str] = None + ) -> 'RegisteredSchema': + """ + Retrieves a specific schema registered under ``subject_name``. + + Args: + subject_name (str): Subject name. + version (int): version number. Defaults to latest version. + deleted (bool): Whether to include deleted schemas. + fmt (str): Format of the schema + + Returns: + RegisteredSchema: Registration information for this version. + + Raises: + SchemaRegistryError: if the version can't be found or is invalid. + + See Also: + `GET Subject Version API Reference `_ + """ # noqa: E501 + + registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) + if registered_schema is not None: + return registered_schema + + query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} + response = await self._rest_client.get( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version), query + ) + + registered_schema = RegisteredSchema.from_dict(response) + + self._cache.set_registered_schema(registered_schema.schema, registered_schema) + + return registered_schema + + async def get_versions(self, subject_name: str) -> List[int]: + """ + Get a list of all versions registered with this subject. + + Args: + subject_name (str): Subject name. 
+ + Returns: + list(int): Registered versions + + Raises: + SchemaRegistryError: If subject can't be found + + See Also: + `GET Subject Versions API Reference `_ + """ # noqa: E501 + + return await self._rest_client.get('subjects/{}/versions'.format(_urlencode(subject_name))) + + async def delete_version(self, subject_name: str, version: int, permanent: bool = False) -> int: + """ + Deletes a specific version registered to ``subject_name``. + + Args: + subject_name (str) Subject name + + version (int): Version number + + permanent (bool): True for a hard delete, False (default) for a soft delete + + Returns: + int: Version number which was deleted + + Raises: + SchemaRegistryError: if the subject or version cannot be found. + + See Also: + `Delete Subject Version API Reference `_ + """ # noqa: E501 + + if permanent: + response = await self._rest_client.delete( + 'subjects/{}/versions/{}?permanent=true'.format(_urlencode(subject_name), version) + ) + self._cache.remove_by_subject_version(subject_name, version) + else: + response = await self._rest_client.delete( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version) + ) + + return response + + async def set_compatibility(self, subject_name: Optional[str] = None, level: Optional[str] = None) -> str: + """ + Update global or subject level compatibility level. + + Args: + level (str): Compatibility level. See API reference for a list of + valid values. + + subject_name (str, optional): Subject to update. Sets global compatibility + level policy if not set. + + Returns: + str: The newly configured compatibility level. + + Raises: + SchemaRegistryError: If the compatibility level is invalid. + + See Also: + `PUT Subject Compatibility API Reference `_ + """ # noqa: E501 + + if level is None: + raise ValueError("level must be set") + + if subject_name is None: + return await self._rest_client.put( + 'config', body={'compatibility': level.upper()} + ) + + return await self._rest_client.put( + 'config/{}'.format(_urlencode(subject_name)), body={'compatibility': level.upper()} + ) + + async def get_compatibility(self, subject_name: Optional[str] = None) -> str: + """ + Get the current compatibility level. + + Args: + subject_name (str, optional): Subject name. Returns global policy + if left unset. + + Returns: + str: Compatibility level for the subject if set, otherwise the global compatibility level. + + Raises: + SchemaRegistryError: if the request was unsuccessful. + + See Also: + `GET Subject Compatibility API Reference `_ + """ # noqa: E501 + + if subject_name is not None: + url = 'config/{}'.format(_urlencode(subject_name)) + else: + url = 'config' + + result = await self._rest_client.get(url) + return result['compatibilityLevel'] + + async def test_compatibility( + self, subject_name: str, schema: 'Schema', + version: Union[int, str] = "latest" + ) -> bool: + """Test the compatibility of a candidate schema for a given subject and version + + Args: + subject_name (str): Subject name the schema is registered under + schema (Schema): Schema instance. + version (int or str, optional): Version number, or the string "latest". Defaults to "latest". + + Returns: + bool: True if the schema is compatible with the specified version + + Raises: + SchemaRegistryError: if the request was unsuccessful. 
+ + See Also: + `POST Test Compatibility API Reference `_ + """ # noqa: E501 + + request = schema.to_dict() + + response = await self._rest_client.post( + 'compatibility/subjects/{}/versions/{}'.format(_urlencode(subject_name), version), body=request + ) + + return response['is_compatible'] + + async def set_config( + self, subject_name: Optional[str] = None, + config: Optional['ServerConfig'] = None + ) -> 'ServerConfig': + """ + Update global or subject config. + + Args: + config (ServerConfig): Config. See API reference for a list of + valid values. + + subject_name (str, optional): Subject to update. Sets global config + if not set. + + Returns: + str: The newly configured config. + + Raises: + SchemaRegistryError: If the config is invalid. + + See Also: + `PUT Subject Config API Reference `_ + """ # noqa: E501 + + if config is None: + raise ValueError("config must be set") + + if subject_name is None: + return await self._rest_client.put( + 'config', body=config.to_dict() + ) + + return await self._rest_client.put( + 'config/{}'.format(_urlencode(subject_name)), body=config.to_dict() + ) + + async def get_config(self, subject_name: Optional[str] = None) -> 'ServerConfig': + """ + Get the current config. + + Args: + subject_name (str, optional): Subject name. Returns global config + if left unset. + + Returns: + ServerConfig: Config for the subject if set, otherwise the global config. + + Raises: + SchemaRegistryError: if the request was unsuccessful. + + See Also: + `GET Subject Config API Reference `_ + """ # noqa: E501 + + if subject_name is not None: + url = 'config/{}'.format(_urlencode(subject_name)) + else: + url = 'config' + + result = await self._rest_client.get(url) + return ServerConfig.from_dict(result) + + def clear_latest_caches(self): + self._latest_version_cache.clear() + self._latest_with_metadata_cache.clear() + + def clear_caches(self): + self._latest_version_cache.clear() + self._latest_with_metadata_cache.clear() + self._cache.clear() + + @staticmethod + def new_client(conf: dict) -> 'AsyncSchemaRegistryClient': + from confluent_kafka.schema_registry.mock_schema_registry_client import MockSchemaRegistryClient + url = conf.get("url") + if url.startswith("mock://"): + return MockSchemaRegistryClient(conf) + return AsyncSchemaRegistryClient(conf) diff --git a/src/confluent_kafka/schema_registry/_async/serde.py b/src/confluent_kafka/schema_registry/_async/serde.py new file mode 100644 index 000000000..40afca25e --- /dev/null +++ b/src/confluent_kafka/schema_registry/_async/serde.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2024 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
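
Before the serde internals that follow, here is a minimal usage sketch of the async client methods documented above. The registry URL and subject name are illustrative only, error handling is omitted, and the construction style (no `await` on the client constructor) mirrors the test fixtures added later in this change.

```python
import asyncio

from confluent_kafka.schema_registry import AsyncSchemaRegistryClient, Schema


async def main():
    # Hypothetical registry URL; point this at a reachable Schema Registry.
    sr = AsyncSchemaRegistryClient({'url': 'http://localhost:8081'})

    subject = 'example-value'
    schema = Schema('{"type": "string"}', schema_type='AVRO')

    # Every request method on the async client is awaited.
    schema_id = await sr.register_schema(subject, schema)
    registered = await sr.lookup_schema(subject, schema)
    assert registered.schema_id == schema_id

    # Compatibility helpers mirror the sync client.
    await sr.set_compatibility(subject_name=subject, level='FULL_TRANSITIVE')
    print(await sr.get_compatibility(subject_name=subject))
    print(await sr.test_compatibility(subject, schema))


asyncio.run(main())
```
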
+# + +import logging +from typing import List, Optional, Set, Dict, Any + +from confluent_kafka.schema_registry import RegisteredSchema +from confluent_kafka.schema_registry.common.serde import ErrorAction, \ + FieldTransformer, Migration, NoneAction, RuleAction, \ + RuleConditionError, RuleContext, RuleError, SchemaId +from confluent_kafka.schema_registry.schema_registry_client import RuleMode, \ + Rule, RuleKind, Schema, RuleSet +from confluent_kafka.serialization import Serializer, Deserializer, \ + SerializationContext, SerializationError + +__all__ = [ + 'AsyncBaseSerde', + 'AsyncBaseSerializer', + 'AsyncBaseDeserializer', +] + +log = logging.getLogger(__name__) + + +class AsyncBaseSerde(object): + __slots__ = ['_use_schema_id', '_use_latest_version', '_use_latest_with_metadata', + '_registry', '_rule_registry', '_subject_name_func', + '_field_transformer'] + + async def _get_reader_schema(self, subject: str, fmt: Optional[str] = None) -> Optional[RegisteredSchema]: + if self._use_schema_id is not None: + schema = await self._registry.get_schema(self._use_schema_id, subject, fmt) + return await self._registry.lookup_schema(subject, schema, False, True) + if self._use_latest_with_metadata is not None: + return await self._registry.get_latest_with_metadata( + subject, self._use_latest_with_metadata, True, fmt) + if self._use_latest_version: + return await self._registry.get_latest_version(subject, fmt) + return None + + def _execute_rules( + self, ser_ctx: SerializationContext, subject: str, + rule_mode: RuleMode, + source: Optional[Schema], target: Optional[Schema], + message: Any, inline_tags: Optional[Dict[str, Set[str]]], + field_transformer: Optional[FieldTransformer] + ) -> Any: + if message is None or target is None: + return message + rules: Optional[List[Rule]] = None + if rule_mode == RuleMode.UPGRADE: + if target is not None and target.rule_set is not None: + rules = target.rule_set.migration_rules + elif rule_mode == RuleMode.DOWNGRADE: + if source is not None and source.rule_set is not None: + rules = source.rule_set.migration_rules + rules = rules[:] if rules else [] + rules.reverse() + else: + if target is not None and target.rule_set is not None: + rules = target.rule_set.domain_rules + if rule_mode == RuleMode.READ: + # Execute read rules in reverse order for symmetry + rules = rules[:] if rules else [] + rules.reverse() + + if not rules: + return message + + for index in range(len(rules)): + rule = rules[index] + if self._is_disabled(rule): + continue + if rule.mode == RuleMode.WRITEREAD: + if rule_mode != RuleMode.READ and rule_mode != RuleMode.WRITE: + continue + elif rule.mode == RuleMode.UPDOWN: + if rule_mode != RuleMode.UPGRADE and rule_mode != RuleMode.DOWNGRADE: + continue + elif rule.mode != rule_mode: + continue + + ctx = RuleContext(ser_ctx, source, target, subject, rule_mode, rule, + index, rules, inline_tags, field_transformer) + rule_executor = self._rule_registry.get_executor(rule.type.upper()) + if rule_executor is None: + self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), message, + RuleError(f"Could not find rule executor of type {rule.type}"), + 'ERROR') + return message + try: + result = rule_executor.transform(ctx, message) + if rule.kind == RuleKind.CONDITION: + if not result: + raise RuleConditionError(rule) + elif rule.kind == RuleKind.TRANSFORM: + message = result + self._run_action( + ctx, rule_mode, rule, + self._get_on_failure(rule) if message is None else self._get_on_success(rule), + message, None, + 'ERROR' if message is 
None else 'NONE') + except SerializationError: + raise + except Exception as e: + self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), + message, e, 'ERROR') + return message + + def _get_on_success(self, rule: Rule) -> Optional[str]: + override = self._rule_registry.get_override(rule.type) + if override is not None and override.on_success is not None: + return override.on_success + return rule.on_success + + def _get_on_failure(self, rule: Rule) -> Optional[str]: + override = self._rule_registry.get_override(rule.type) + if override is not None and override.on_failure is not None: + return override.on_failure + return rule.on_failure + + def _is_disabled(self, rule: Rule) -> Optional[bool]: + override = self._rule_registry.get_override(rule.type) + if override is not None and override.disabled is not None: + return override.disabled + return rule.disabled + + def _run_action( + self, ctx: RuleContext, rule_mode: RuleMode, rule: Rule, + action: Optional[str], message: Any, + ex: Optional[Exception], default_action: str + ): + action_name = self._get_rule_action_name(rule, rule_mode, action) + if action_name is None: + action_name = default_action + rule_action = self._get_rule_action(ctx, action_name) + if rule_action is None: + log.error("Could not find rule action of type %s", action_name) + raise RuleError(f"Could not find rule action of type {action_name}") + try: + rule_action.run(ctx, message, ex) + except SerializationError: + raise + except Exception as e: + log.warning("Could not run post-rule action %s: %s", action_name, e) + + def _get_rule_action_name( + self, rule: Rule, rule_mode: RuleMode, action_name: Optional[str] + ) -> Optional[str]: + if action_name is None or action_name == "": + return None + if rule.mode in (RuleMode.WRITEREAD, RuleMode.UPDOWN) and ',' in action_name: + parts = action_name.split(',') + if rule_mode in (RuleMode.WRITE, RuleMode.UPGRADE): + return parts[0] + elif rule_mode in (RuleMode.READ, RuleMode.DOWNGRADE): + return parts[1] + return action_name + + def _get_rule_action(self, ctx: RuleContext, action_name: str) -> Optional[RuleAction]: + if action_name == 'ERROR': + return ErrorAction() + elif action_name == 'NONE': + return NoneAction() + return self._rule_registry.get_action(action_name) + + +class AsyncBaseSerializer(AsyncBaseSerde, Serializer): + __slots__ = ['_auto_register', '_normalize_schemas', '_schema_id_serializer'] + + +class AsyncBaseDeserializer(AsyncBaseSerde, Deserializer): + __slots__ = ['_schema_id_deserializer'] + + async def _get_writer_schema( + self, schema_id: SchemaId, subject: Optional[str] = None, + fmt: Optional[str] = None) -> Schema: + if schema_id.id is not None: + return await self._registry.get_schema(schema_id.id, subject, fmt) + elif schema_id.guid is not None: + return await self._registry.get_schema_by_guid(str(schema_id.guid), fmt) + else: + raise SerializationError("Schema ID or GUID is not set") + + def _has_rules(self, rule_set: RuleSet, mode: RuleMode) -> bool: + if rule_set is None: + return False + if mode in (RuleMode.UPGRADE, RuleMode.DOWNGRADE): + return any(rule.mode == mode or rule.mode == RuleMode.UPDOWN + for rule in rule_set.migration_rules or []) + elif mode == RuleMode.UPDOWN: + return any(rule.mode == mode for rule in rule_set.migration_rules or []) + elif mode in (RuleMode.WRITE, RuleMode.READ): + return any(rule.mode == mode or rule.mode == RuleMode.WRITEREAD + for rule in rule_set.domain_rules or []) + elif mode == RuleMode.WRITEREAD: + return any(rule.mode == mode for rule in 
rule_set.domain_rules or []) + return False + + async def _get_migrations( + self, subject: str, source_info: Schema, + target: RegisteredSchema, fmt: Optional[str] + ) -> List[Migration]: + source = await self._registry.lookup_schema(subject, source_info, False, True) + migrations = [] + if source.version < target.version: + migration_mode = RuleMode.UPGRADE + first = source + last = target + elif source.version > target.version: + migration_mode = RuleMode.DOWNGRADE + first = target + last = source + else: + return migrations + previous: Optional[RegisteredSchema] = None + versions = await self._get_schemas_between(subject, first, last, fmt) + for i in range(len(versions)): + version = versions[i] + if i == 0: + previous = version + continue + if version.schema.rule_set is not None and self._has_rules(version.schema.rule_set, migration_mode): + if migration_mode == RuleMode.UPGRADE: + migration = Migration(migration_mode, previous, version) + else: + migration = Migration(migration_mode, version, previous) + migrations.append(migration) + previous = version + if migration_mode == RuleMode.DOWNGRADE: + migrations.reverse() + return migrations + + async def _get_schemas_between( + self, subject: str, first: RegisteredSchema, + last: RegisteredSchema, fmt: Optional[str] = None + ) -> List[RegisteredSchema]: + if last.version - first.version <= 1: + return [first, last] + version1 = first.version + version2 = last.version + result = [first] + for i in range(version1 + 1, version2): + result.append(await self._registry.get_version(subject, i, True, fmt)) + result.append(last) + return result + + def _execute_migrations( + self, ser_ctx: SerializationContext, subject: str, + migrations: List[Migration], message: Any + ) -> Any: + for migration in migrations: + message = self._execute_rules(ser_ctx, subject, migration.rule_mode, + migration.source.schema, migration.target.schema, + message, None, None) + return message diff --git a/src/confluent_kafka/schema_registry/_sync/README.md b/src/confluent_kafka/schema_registry/_sync/README.md new file mode 100644 index 000000000..905a46481 --- /dev/null +++ b/src/confluent_kafka/schema_registry/_sync/README.md @@ -0,0 +1,7 @@ +# Auto-generated Directory + +This directory contains auto-generated code. Do not edit these files directly. + +To make changes: +1. Edit the corresponding files in the sibling `_async` directory +2. Run `python tools/unasync.py` to propagate the changes to this `_sync` directory diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index 651dee7ca..b069b8c63 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -32,7 +32,8 @@ from confluent_kafka.serialization import (SerializationError, SerializationContext) from confluent_kafka.schema_registry.rule_registry import RuleRegistry -from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, ParsedSchemaCache, SchemaId +from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, \ + ParsedSchemaCache, SchemaId __all__ = [ @@ -63,6 +64,7 @@ def _resolve_named_schema( return named_schemas + + class AvroSerializer(BaseSerializer): """ Serializer that outputs Avro binary encoded data with Confluent Schema Registry framing.
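
The README above describes how the `_sync` tree is generated from `_async`. As a rough illustration of the kind of source rewriting such a generation step performs, here is a simplified sketch of token substitution; this is not the project's actual `tools/unasync.py`, just the idea behind it.

```python
import re

# Illustrative substitutions only; the real tool handles many more cases
# (imports, test directories, formatting, --check mode, etc.).
_RULES = [
    (re.compile(r'\basync def\b'), 'def'),
    (re.compile(r'\bawait '), ''),
    (re.compile(r'\bAsyncSchemaRegistryClient\b'), 'SchemaRegistryClient'),
    (re.compile(r'\bAsyncAvroSerializer\b'), 'AvroSerializer'),
    (re.compile(r'\bAsyncAvroDeserializer\b'), 'AvroDeserializer'),
]


def unasync_source(text: str) -> str:
    """Rewrite async source text into its sync counterpart (sketch)."""
    for pattern, replacement in _RULES:
        text = pattern.sub(replacement, text)
    return text


print(unasync_source(
    "async def fetch(self):\n"
    "    return await self._registry.get_latest_version(subject)\n"
))
```
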
@@ -361,6 +363,7 @@ def _get_parsed_schema(self, schema: Schema) -> AvroSchema: return parsed_schema + class AvroDeserializer(BaseDeserializer): """ Deserializer for Avro binary encoded data with Confluent Schema Registry @@ -505,7 +508,8 @@ def __init__( def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: return self.__deserialize(data, ctx) - def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + def __deserialize( + self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: """ Deserialize Avro binary encoded data with Confluent Schema Registry framing to a dict, or object instance according to from_dict, if specified. diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index 61c46cd96..b8520b814 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -71,6 +71,7 @@ def _resolve_named_schema( return ref_registry + class JSONSerializer(BaseSerializer): """ Serializer that outputs JSON encoded data with Confluent Schema Registry framing. @@ -401,6 +402,7 @@ def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Re return validator + class JSONDeserializer(BaseDeserializer): """ Deserializer for JSON encoded data with Confluent Schema Registry diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index 7d4070465..b1a7cf0b0 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -38,7 +38,8 @@ from confluent_kafka.serialization import SerializationError, \ SerializationContext -from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, ParsedSchemaCache, SchemaId +from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, \ + ParsedSchemaCache, SchemaId __all__ = [ '_resolve_named_schema', @@ -73,6 +74,7 @@ def _resolve_named_schema( pool.Add(file_descriptor_proto) + class ProtobufSerializer(BaseSerializer): """ Serializer for Protobuf Message derived classes. Serialization format is Protobuf, @@ -458,6 +460,7 @@ def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescrip return fd_proto, pool + class ProtobufDeserializer(BaseDeserializer): """ Deserializer for Protobuf serialized data with Confluent Schema Registry framing. 
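
The sync serializers touched above keep their existing call pattern. As a reminder of how they are used standalone (outside a Producer/Consumer), here is a brief sketch; the registry URL, topic, and schema are illustrative and a reachable Schema Registry is assumed.

```python
from confluent_kafka.schema_registry import SchemaRegistryClient
from confluent_kafka.schema_registry.avro import AvroSerializer, AvroDeserializer
from confluent_kafka.serialization import MessageField, SerializationContext

# Illustrative registry URL.
sr = SchemaRegistryClient({'url': 'http://localhost:8081'})

schema_str = """
{"type": "record", "name": "User",
 "fields": [{"name": "name", "type": "string"}]}
"""
serializer = AvroSerializer(sr, schema_str)
deserializer = AvroDeserializer(sr)

ctx = SerializationContext('users', MessageField.VALUE)
payload = serializer({'name': 'alice'}, ctx)   # bytes with Schema Registry framing
print(deserializer(payload, ctx))              # -> {'name': 'alice'}
```
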
diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index 65d700f83..f2caffc97 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -283,10 +283,11 @@ def __init__(self, conf: dict): raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " + str(type(self.token_endpoint))) - self.bearer_field_provider = _OAuthClient(self.client_id, self.client_secret, self.scope, - self.token_endpoint, logical_cluster, identity_pool, - self.max_retries, self.retries_wait_ms, - self.retries_max_wait_ms) + self.bearer_field_provider = _OAuthClient( + self.client_id, self.client_secret, self.scope, + self.token_endpoint, logical_cluster, identity_pool, + self.max_retries, self.retries_wait_ms, + self.retries_max_wait_ms) elif self.bearer_auth_credentials_source == 'STATIC_TOKEN': if 'bearer.auth.token' not in conf_copy: raise ValueError("Missing bearer.auth.token") @@ -761,9 +762,11 @@ def lookup_schema( request = schema.to_dict() - response = self._rest_client.post('subjects/{}?normalize={}&deleted={}' - .format(_urlencode(subject_name), normalize_schemas, deleted), - body=request) + response = self._rest_client.post( + 'subjects/{}?normalize={}&deleted={}'.format( + _urlencode(subject_name), normalize_schemas, deleted), + body=request + ) result = RegisteredSchema.from_dict(response) @@ -817,12 +820,14 @@ def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int """ # noqa: E501 if permanent: - versions = self._rest_client.delete('subjects/{}?permanent=true' - .format(_urlencode(subject_name))) + versions = self._rest_client.delete( + 'subjects/{}?permanent=true'.format(_urlencode(subject_name)) + ) self._cache.remove_by_subject(subject_name) else: - versions = self._rest_client.delete('subjects/{}' - .format(_urlencode(subject_name))) + versions = self._rest_client.delete( + 'subjects/{}'.format(_urlencode(subject_name)) + ) return versions @@ -851,9 +856,9 @@ def get_latest_version( return registered_schema query = {'format': fmt} if fmt is not None else None - response = self._rest_client.get('subjects/{}/versions/{}' - .format(_urlencode(subject_name), - 'latest'), query) + response = self._rest_client.get( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), 'latest'), query + ) registered_schema = RegisteredSchema.from_dict(response) @@ -892,8 +897,9 @@ def get_latest_with_metadata( query['key'] = [_urlencode(key) for key in keys] query['value'] = [_urlencode(metadata[key]) for key in keys] - response = self._rest_client.get('subjects/{}/metadata' - .format(_urlencode(subject_name)), query) + response = self._rest_client.get( + 'subjects/{}/metadata'.format(_urlencode(subject_name)), query + ) registered_schema = RegisteredSchema.from_dict(response) @@ -929,9 +935,9 @@ def get_version( return registered_schema query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} - response = self._rest_client.get('subjects/{}/versions/{}' - .format(_urlencode(subject_name), - version), query) + response = self._rest_client.get( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version), query + ) registered_schema = RegisteredSchema.from_dict(response) @@ -980,14 +986,14 @@ def delete_version(self, subject_name: str, version: int, permanent: bool = Fals """ # noqa: E501 if permanent: - response = 
self._rest_client.delete('subjects/{}/versions/{}?permanent=true' - .format(_urlencode(subject_name), - version)) + response = self._rest_client.delete( + 'subjects/{}/versions/{}?permanent=true'.format(_urlencode(subject_name), version) + ) self._cache.remove_by_subject_version(subject_name, version) else: - response = self._rest_client.delete('subjects/{}/versions/{}' - .format(_urlencode(subject_name), - version)) + response = self._rest_client.delete( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version) + ) return response @@ -1016,12 +1022,13 @@ def set_compatibility(self, subject_name: Optional[str] = None, level: Optional[ raise ValueError("level must be set") if subject_name is None: - return self._rest_client.put('config', - body={'compatibility': level.upper()}) + return self._rest_client.put( + 'config', body={'compatibility': level.upper()} + ) - return self._rest_client.put('config/{}' - .format(_urlencode(subject_name)), - body={'compatibility': level.upper()}) + return self._rest_client.put( + 'config/{}'.format(_urlencode(subject_name)), body={'compatibility': level.upper()} + ) def get_compatibility(self, subject_name: Optional[str] = None) -> str: """ @@ -1106,12 +1113,13 @@ def set_config( raise ValueError("config must be set") if subject_name is None: - return self._rest_client.put('config', - body=config.to_dict()) + return self._rest_client.put( + 'config', body=config.to_dict() + ) - return self._rest_client.put('config/{}' - .format(_urlencode(subject_name)), - body=config.to_dict()) + return self._rest_client.put( + 'config/{}'.format(_urlencode(subject_name)), body=config.to_dict() + ) def get_config(self, subject_name: Optional[str] = None) -> 'ServerConfig': """ diff --git a/src/confluent_kafka/schema_registry/_sync/serde.py b/src/confluent_kafka/schema_registry/_sync/serde.py index 84ce15726..f0512f872 100644 --- a/src/confluent_kafka/schema_registry/_sync/serde.py +++ b/src/confluent_kafka/schema_registry/_sync/serde.py @@ -187,8 +187,9 @@ class BaseSerializer(BaseSerde, Serializer): class BaseDeserializer(BaseSerde, Deserializer): __slots__ = ['_schema_id_deserializer'] - def _get_writer_schema(self, schema_id: SchemaId, subject: Optional[str] = None, - fmt: Optional[str] = None) -> Schema: + def _get_writer_schema( + self, schema_id: SchemaId, subject: Optional[str] = None, + fmt: Optional[str] = None) -> Schema: if schema_id.id is not None: return self._registry.get_schema(schema_id.id, subject, fmt) elif schema_id.guid is not None: diff --git a/src/confluent_kafka/schema_registry/avro.py b/src/confluent_kafka/schema_registry/avro.py index 0a9ba2db2..1c0024695 100644 --- a/src/confluent_kafka/schema_registry/avro.py +++ b/src/confluent_kafka/schema_registry/avro.py @@ -16,4 +16,5 @@ # limitations under the License. from .common.avro import * # noqa +from ._async.avro import * # noqa from ._sync.avro import * # noqa diff --git a/src/confluent_kafka/schema_registry/common/__init__.py b/src/confluent_kafka/schema_registry/common/__init__.py index e69de29bb..abbd1c448 100644 --- a/src/confluent_kafka/schema_registry/common/__init__.py +++ b/src/confluent_kafka/schema_registry/common/__init__.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +def asyncinit(cls): + """ + Decorator to make a class async-initializable. + """ + __new__ = cls.__new__ + + async def init(obj, *arg, **kwarg): + await obj.__init__(*arg, **kwarg) + return obj + + def new(klass, *arg, **kwarg): + obj = __new__(klass) + coro = init(obj, *arg, **kwarg) + return coro + + cls.__new__ = new + return cls diff --git a/src/confluent_kafka/schema_registry/json_schema.py b/src/confluent_kafka/schema_registry/json_schema.py index e60c8eafd..ca4d74ee9 100644 --- a/src/confluent_kafka/schema_registry/json_schema.py +++ b/src/confluent_kafka/schema_registry/json_schema.py @@ -16,4 +16,5 @@ # limitations under the License. from .common.json_schema import * # noqa +from ._async.json_schema import * # noqa from ._sync.json_schema import * # noqa diff --git a/src/confluent_kafka/schema_registry/protobuf.py b/src/confluent_kafka/schema_registry/protobuf.py index c781e47be..5127398b6 100644 --- a/src/confluent_kafka/schema_registry/protobuf.py +++ b/src/confluent_kafka/schema_registry/protobuf.py @@ -16,4 +16,5 @@ # limitations under the License. from .common.protobuf import * # noqa +from ._async.protobuf import * # noqa from ._sync.protobuf import * # noqa diff --git a/src/confluent_kafka/schema_registry/schema_registry_client.py b/src/confluent_kafka/schema_registry/schema_registry_client.py index b7a008097..be6c70119 100644 --- a/src/confluent_kafka/schema_registry/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/schema_registry_client.py @@ -17,6 +17,7 @@ from .common.schema_registry_client import * # noqa +from ._async.schema_registry_client import * # noqa from ._sync.schema_registry_client import * # noqa from .error import SchemaRegistryError # noqa diff --git a/src/confluent_kafka/schema_registry/serde.py b/src/confluent_kafka/schema_registry/serde.py index 25fe4a48f..0574fc3b1 100644 --- a/src/confluent_kafka/schema_registry/serde.py +++ b/src/confluent_kafka/schema_registry/serde.py @@ -17,4 +17,5 @@ # from .common.serde import * # noqa +from ._async.serde import * # noqa from ._sync.serde import * # noqa diff --git a/tests/common/_async/__init__.py b/tests/common/_async/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/common/_async/consumer.py b/tests/common/_async/consumer.py new file mode 100644 index 000000000..81f036ca0 --- /dev/null +++ b/tests/common/_async/consumer.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2025 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
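
The `asyncinit` decorator added above is what allows the async serializers to be constructed with `await`, so their `__init__` can perform registry calls. A small, self-contained illustration with a hypothetical class:

```python
import asyncio

from confluent_kafka.schema_registry.common import asyncinit


@asyncinit
class ExampleClient:
    # With asyncinit, __init__ is a coroutine, so construction-time async work is possible.
    async def __init__(self, name):
        self.name = name
        await asyncio.sleep(0)  # stand-in for async setup, e.g. a schema lookup


async def main():
    # Calling the class returns a coroutine; awaiting it yields the initialized instance.
    client = await ExampleClient("demo")
    print(client.name)


asyncio.run(main())
```
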
+# +import sys +import asyncio +if sys.version_info >= (3, 11): + from asyncio import timeout +else: + from async_timeout import timeout # noqa: F401 + +from confluent_kafka.cimpl import Consumer +from confluent_kafka.error import ConsumeError, KeyDeserializationError, ValueDeserializationError +from confluent_kafka.serialization import MessageField, SerializationContext + +ASYNC_CONSUMER_POLL_INTERVAL_SECONDS: int = 0.2 +ASYNC_CONSUMER_POLL_INFINITE_TIMEOUT_SECONDS: int = -1 + + +class AsyncConsumer(Consumer): + def __init__( + self, + conf: dict, + loop: asyncio.AbstractEventLoop = None, + poll_interval_seconds: int = ASYNC_CONSUMER_POLL_INTERVAL_SECONDS + ): + super().__init__(conf) + + self._loop = loop or asyncio.get_event_loop() + self._poll_interval = poll_interval_seconds + + def __aiter__(self): + return self + + async def __anext__(self): + return await self.poll(None) + + async def poll(self, poll_timeout: int = -1): + poll_timeout = None if poll_timeout == -1 else poll_timeout + async with timeout(poll_timeout): + while True: + # Zero timeout here is what makes it non-blocking + msg = super().poll(0) + if msg is not None: + return msg + else: + await asyncio.sleep(self._poll_interval) + + +class TestAsyncDeserializingConsumer(AsyncConsumer): + def __init__(self, conf): + conf_copy = conf.copy() + self._key_deserializer = conf_copy.pop('key.deserializer', None) + self._value_deserializer = conf_copy.pop('value.deserializer', None) + super().__init__(conf_copy) + + async def poll(self, poll_timeout=-1): + msg = await super().poll(poll_timeout) + + if msg is None: + return None + + if msg.error() is not None: + raise ConsumeError(msg.error(), kafka_message=msg) + + ctx = SerializationContext(msg.topic(), MessageField.VALUE, msg.headers()) + value = msg.value() + if self._value_deserializer is not None: + try: + value = await self._value_deserializer(value, ctx) + except Exception as se: + raise ValueDeserializationError(exception=se, kafka_message=msg) + + key = msg.key() + ctx.field = MessageField.KEY + if self._key_deserializer is not None: + try: + key = await self._key_deserializer(key, ctx) + except Exception as se: + raise KeyDeserializationError(exception=se, kafka_message=msg) + + msg.set_key(key) + msg.set_value(value) + return msg + + def consume(self, num_messages=1, consume_timeout=-1): + """ + :py:func:`Consumer.consume` not implemented, use + :py:func:`DeserializingConsumer.poll` instead + """ + + raise NotImplementedError diff --git a/tests/common/_async/producer.py b/tests/common/_async/producer.py new file mode 100644 index 000000000..882685df2 --- /dev/null +++ b/tests/common/_async/producer.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2025 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
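
The consumer helper above never blocks the event loop: it calls `Consumer.poll(0)` (which returns immediately) and sleeps between attempts, with the overall deadline enforced by `asyncio.timeout` / `async_timeout`. A standalone sketch of that pattern, using a stubbed message source instead of a real consumer:

```python
import asyncio
from typing import Callable, Optional


async def poll_without_blocking(
    try_fetch: Callable[[], Optional[str]],
    poll_interval: float = 0.2,
    overall_timeout: Optional[float] = None,
) -> str:
    async def _loop() -> str:
        while True:
            msg = try_fetch()  # analogous to Consumer.poll(0): never blocks
            if msg is not None:
                return msg
            await asyncio.sleep(poll_interval)  # yield to the event loop between attempts

    if overall_timeout is None:
        return await _loop()
    return await asyncio.wait_for(_loop(), overall_timeout)


# Stubbed source: the third attempt "delivers" a message.
attempts = iter([None, None, "hello"])
print(asyncio.run(poll_without_blocking(lambda: next(attempts), poll_interval=0.01)))
```
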
+# + +from confluent_kafka.cimpl import Producer +import inspect +import asyncio + +from confluent_kafka.error import KeySerializationError, ValueSerializationError +from confluent_kafka.serialization import MessageField, SerializationContext + +ASYNC_PRODUCER_POLL_INTERVAL: int = 0.2 + + +class AsyncProducer(Producer): + def __init__( + self, + conf: dict, + loop: asyncio.AbstractEventLoop = None, + poll_interval: int = ASYNC_PRODUCER_POLL_INTERVAL + ): + super().__init__(conf) + + self._loop = loop or asyncio.get_event_loop() + self._poll_interval = poll_interval + + self._poll_task = None + self._waiters: int = 0 + + async def produce( + self, topic, value=None, key=None, partition=-1, + on_delivery=None, timestamp=0, headers=None + ): + fut = self._loop.create_future() + self._waiters += 1 + try: + if self._poll_task is None or self._poll_task.done(): + self._poll_task = asyncio.create_task(self._poll_dr(self._poll_interval)) + + def wrapped_on_delivery(err, msg): + if on_delivery is not None: + if inspect.iscoroutinefunction(on_delivery): + asyncio.run_coroutine_threadsafe( + on_delivery(err, msg), + self._loop + ) + else: + self._loop.call_soon_threadsafe(on_delivery, err, msg) + + if err: + self._loop.call_soon_threadsafe(fut.set_exception, err) + else: + self._loop.call_soon_threadsafe(fut.set_result, msg) + + super().produce( + topic, + value, + key, + headers=headers, + partition=partition, + timestamp=timestamp, + on_delivery=wrapped_on_delivery + ) + return await fut + finally: + self._waiters -= 1 + + async def _poll_dr(self, interval: int): + """Poll delivery reports at interval seconds""" + while self._waiters: + super().poll(0) + await asyncio.sleep(interval) + + +class TestAsyncSerializingProducer(AsyncProducer): + def __init__(self, conf): + conf_copy = conf.copy() + + self._key_serializer = conf_copy.pop('key.serializer', None) + self._value_serializer = conf_copy.pop('value.serializer', None) + + super(TestAsyncSerializingProducer, self).__init__(conf_copy) + + async def produce( + self, topic, key=None, value=None, partition=-1, + on_delivery=None, timestamp=0, headers=None): + ctx = SerializationContext(topic, MessageField.KEY, headers) + if self._key_serializer is not None: + try: + key = await self._key_serializer(key, ctx) + except Exception as se: + raise KeySerializationError(se) + ctx.field = MessageField.VALUE + if self._value_serializer is not None: + try: + value = await self._value_serializer(value, ctx) + except Exception as se: + raise ValueSerializationError(se) + + return await super().produce( + topic, value, key, + headers=headers, + partition=partition, + timestamp=timestamp, + on_delivery=on_delivery + ) diff --git a/tests/integration/cluster_fixture.py b/tests/integration/cluster_fixture.py index efaa55f70..0d441ca1d 100644 --- a/tests/integration/cluster_fixture.py +++ b/tests/integration/cluster_fixture.py @@ -24,9 +24,12 @@ from confluent_kafka import Producer, SerializingProducer from confluent_kafka.admin import AdminClient, NewTopic from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient +from confluent_kafka.schema_registry._async.schema_registry_client import AsyncSchemaRegistryClient from tests.common import TestConsumer from tests.common.schema_registry import TestDeserializingConsumer +from tests.common._async.consumer import TestAsyncDeserializingConsumer +from tests.common._async.producer import TestAsyncSerializingProducer class KafkaClusterFixture(object): @@ -89,6 +92,32 @@ def producer(self, 
conf=None, key_serializer=None, value_serializer=None): return SerializingProducer(client_conf) + def async_producer(self, conf=None, key_serializer=None, value_serializer=None): + """ + Returns a producer bound to this cluster. + + Args: + conf (dict): Producer configuration overrides + + key_serializer (Serializer): serializer to apply to message key + + value_serializer (Deserializer): serializer to apply to + message value + + Returns: + Producer: A new SerializingProducer instance + + """ + client_conf = self.client_conf(conf) + + if key_serializer is not None: + client_conf['key.serializer'] = key_serializer + + if value_serializer is not None: + client_conf['value.serializer'] = value_serializer + + return TestAsyncSerializingProducer(client_conf) + def cimpl_consumer(self, conf=None): """ Returns a consumer bound to this cluster. @@ -143,6 +172,39 @@ def consumer(self, conf=None, key_deserializer=None, value_deserializer=None): return TestDeserializingConsumer(consumer_conf) + def async_consumer(self, conf=None, key_deserializer=None, value_deserializer=None): + """ + Returns a consumer bound to this cluster. + + Args: + conf (dict): Consumer config overrides + + key_deserializer (Deserializer): deserializer to apply to + message key + + value_deserializer (Deserializer): deserializer to apply to + message value + + Returns: + Consumer: A new DeserializingConsumer instance + + """ + consumer_conf = self.client_conf({ + 'group.id': str(uuid1()), + 'auto.offset.reset': 'earliest' + }) + + if conf is not None: + consumer_conf.update(conf) + + if key_deserializer is not None: + consumer_conf['key.deserializer'] = key_deserializer + + if value_deserializer is not None: + consumer_conf['value.deserializer'] = value_deserializer + + return TestAsyncDeserializingConsumer(consumer_conf) + def admin(self, conf=None): if conf: # When conf is passed create a new AdminClient @@ -286,6 +348,15 @@ def schema_registry(self, conf=None): sr_conf.update(conf) return SchemaRegistryClient(sr_conf) + def async_schema_registry(self, conf=None): + if not hasattr(self._cluster, 'sr'): + return None + + sr_conf = {'url': self._cluster.sr.get('url')} + if conf is not None: + sr_conf.update(conf) + return AsyncSchemaRegistryClient(sr_conf) + def client_conf(self, conf=None): """ Default client configuration diff --git a/tests/integration/integration_test.py b/tests/integration/integration_test.py index d291c8751..c7e44e021 100755 --- a/tests/integration/integration_test.py +++ b/tests/integration/integration_test.py @@ -28,6 +28,7 @@ import gc import struct import re +import pytest import confluent_kafka @@ -213,6 +214,7 @@ def verify_producer(): DrOnlyTestSuccess_gced = 0 +@pytest.mark.skip(reason="This test must be run as a standalone script") def test_producer_dr_only_error(): """ The C delivery.report.only.error configuration property diff --git a/tests/integration/schema_registry/_async/__init__.py b/tests/integration/schema_registry/_async/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/schema_registry/_async/test_api_client.py b/tests/integration/schema_registry/_async/test_api_client.py new file mode 100644 index 000000000..7ba3ee6c9 --- /dev/null +++ b/tests/integration/schema_registry/_async/test_api_client.py @@ -0,0 +1,493 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
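
The producer helper above turns librdkafka's callback-based delivery report into an awaitable: the delivery callback resolves a future, and because the callback fires on a non-asyncio thread, the result is handed back to the loop with `call_soon_threadsafe`. A reduced sketch of that idea, with the broker callback simulated by a timer thread:

```python
import asyncio
import threading


def fake_delivery(callback):
    # Simulate a delivery callback arriving on a foreign (non-asyncio) thread.
    threading.Timer(0.05, callback, args=(None, "delivered message")).start()


async def produce_and_wait() -> str:
    loop = asyncio.get_running_loop()
    delivery: asyncio.Future = loop.create_future()

    def on_delivery(err, msg):
        # Marshal the outcome back onto the event loop thread.
        if err is not None:
            loop.call_soon_threadsafe(delivery.set_exception, RuntimeError(err))
        else:
            loop.call_soon_threadsafe(delivery.set_result, msg)

    fake_delivery(on_delivery)
    return await delivery


print(asyncio.run(produce_and_wait()))
```
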
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from uuid import uuid1 + +import pytest + +from confluent_kafka.schema_registry import Schema +from confluent_kafka.schema_registry.error import SchemaRegistryError +from tests.integration.conftest import kafka_cluster_fixture + + +@pytest.fixture(scope="module") +def kafka_cluster_cp_7_0_1(): + """ + Returns a Trivup cluster with CP version 7.0.1. + SR version 7.0.1 is the last returning 500 instead of 422 + for the invalid schema passed to test_api_get_register_schema_invalid + """ + for fixture in kafka_cluster_fixture( + brokers_env="BROKERS_7_0_1", + sr_url_env="SR_URL_7_0_1", + trivup_cluster_conf={'cp_version': '7.0.1'} + ): + yield fixture + + +def _subject_name(prefix): + return prefix + "-" + str(uuid1()) + + +async def test_api_register_schema(kafka_cluster, load_file): + """ + Registers a schema, verifies the registration + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + avsc = 'basic_schema.avsc' + subject = _subject_name(avsc) + schema = Schema(load_file(avsc), schema_type='AVRO') + + schema_id = await sr.register_schema(subject, schema) + registered_schema = await sr.lookup_schema(subject, schema) + + assert registered_schema.schema_id == schema_id + assert registered_schema.subject == subject + assert schema.schema_str, registered_schema.schema.schema_str + + +async def test_api_register_normalized_schema(kafka_cluster, load_file): + """ + Registers a schema, verifies the registration + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + avsc = 'basic_schema.avsc' + subject = _subject_name(avsc) + schema = Schema(load_file(avsc), schema_type='AVRO') + + schema_id = await sr.register_schema(subject, schema, True) + registered_schema = await sr.lookup_schema(subject, schema, True) + + assert registered_schema.schema_id == schema_id + assert registered_schema.subject == subject + assert schema.schema_str, registered_schema.schema.schema_str + + +async def test_api_register_schema_incompatible(kafka_cluster, load_file): + """ + Attempts to register an incompatible Schema verifies the error. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + schema1 = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + schema2 = Schema(load_file('adv_schema.avsc'), schema_type='AVRO') + subject = _subject_name('test_register_incompatible') + + await sr.register_schema(subject, schema1) + + with pytest.raises(SchemaRegistryError, match="Schema being registered is" + " incompatible with an" + " earlier schema") as e: + # The default Schema Registry compatibility type is BACKWARD. + # this allows 1) fields to be deleted, 2) optional fields to + # be added. schema2 adds non-optional fields to schema1, so + # registering schema2 after schema1 should fail. 
+ await sr.register_schema(subject, schema2) + assert e.value.http_status_code == 409 # conflict + assert e.value.error_code == 409 + + +async def test_api_register_schema_invalid(kafka_cluster, load_file): + """ + Attempts to register an invalid schema, validates the error. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + schema = Schema(load_file('invalid_schema.avsc'), schema_type='AVRO') + subject = _subject_name('test_invalid_schema') + + with pytest.raises(SchemaRegistryError) as e: + await sr.register_schema(subject, schema) + assert e.value.http_status_code == 422 + assert e.value.error_code == 42201 + + +async def test_api_get_schema(kafka_cluster, load_file): + """ + Registers a schema then retrieves it using the schema id returned by the + call to register the Schema. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + subject = _subject_name('get_schema') + + schema_id = await sr.register_schema(subject, schema) + registration = await sr.lookup_schema(subject, schema) + + assert registration.schema_id == schema_id + assert registration.subject == subject + assert schema.schema_str, registration.schema.schema_str + + +async def test_api_get_schema_not_found(kafka_cluster, load_file): + """ + Attempts to fetch an unknown schema by id, validates the error. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + + with pytest.raises(SchemaRegistryError, match="Schema .*not found.*") as e: + await sr.get_schema(999999) + + assert e.value.http_status_code == 404 + assert e.value.error_code == 40403 + + +async def test_api_get_registration_subject_not_found(kafka_cluster, load_file): + """ + Attempts to obtain information about a schema's subject registration for + an unknown subject. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + + subject = _subject_name("registration_subject_not_found") + + with pytest.raises(SchemaRegistryError, match="Subject .*not found.*") as e: + await sr.lookup_schema(subject, schema) + assert e.value.http_status_code == 404 + assert e.value.error_code == 40401 + + +@pytest.mark.parametrize("kafka_cluster_name, http_status_code, error_code", [ + ["kafka_cluster_cp_7_0_1", 500, 500], + ["kafka_cluster", 422, 42201], +]) +async def test_api_get_register_schema_invalid( + kafka_cluster_name, + http_status_code, + error_code, + load_file, + request): + """ + Attempts to obtain registration information with an invalid schema + with different CP versions. 
+ + Args: + kafka_cluster_name (str): name of the Kafka Cluster fixture to use + http_status_code (int): HTTP status return code expected in this version + error_code (int): error code expected in this version + load_file (callable(str)): Schema fixture constructor + request (FixtureRequest): PyTest object giving access to the test context + """ + kafka_cluster = request.getfixturevalue(kafka_cluster_name) + sr = kafka_cluster.async_schema_registry() + subject = _subject_name("registration_invalid_schema") + schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + + # register valid schema so we don't hit subject not found exception + await sr.register_schema(subject, schema) + schema2 = Schema(load_file('invalid_schema.avsc'), schema_type='AVRO') + + with pytest.raises(SchemaRegistryError, match="Invalid schema") as e: + await sr.lookup_schema(subject, schema2) + + assert e.value.http_status_code == http_status_code + assert e.value.error_code == error_code + + +async def test_api_get_subjects(kafka_cluster, load_file): + """ + Populates KafkaClusterFixture SR instance with a fixed number of subjects + then verifies the response includes them all. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + + avscs = ['basic_schema.avsc', 'primitive_string.avsc', + 'primitive_bool.avsc', 'primitive_float.avsc'] + + subjects = [] + for avsc in avscs: + schema = Schema(load_file(avsc), schema_type='AVRO') + subject = _subject_name(avsc) + subjects.append(subject) + + await sr.register_schema(subject, schema) + + registered = await sr.get_subjects() + + assert all([s in registered for s in subjects]) + + +async def test_api_get_subject_versions(kafka_cluster, load_file): + """ + Registers a Schema with a subject, lists the versions associated with that + subject and ensures the versions and their schemas match what was + registered. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor. + + """ + sr = kafka_cluster.async_schema_registry() + + subject = _subject_name("list-version-test") + await sr.set_compatibility(level="NONE") + + avscs = ['basic_schema.avsc', 'primitive_string.avsc', + 'primitive_bool.avsc', 'primitive_float.avsc'] + + schemas = [] + for avsc in avscs: + schema = Schema(load_file(avsc), schema_type='AVRO') + schemas.append(schema) + await sr.register_schema(subject, schema) + + versions = await sr.get_versions(subject) + assert len(versions) == len(avscs) + for schema in schemas: + registered_schema = await sr.lookup_schema(subject, schema) + assert registered_schema.subject == subject + assert registered_schema.version in versions + + # revert global compatibility level back to the default. + await sr.set_compatibility(level="BACKWARD") + + +async def test_api_delete_subject(kafka_cluster, load_file): + """ + Registers a Schema under a specific subject then deletes it. 
+ + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + + schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + subject = _subject_name("test-delete") + + await sr.register_schema(subject, schema) + assert subject in await sr.get_subjects() + + await sr.delete_subject(subject) + assert subject not in await sr.get_subjects() + + +async def test_api_delete_subject_not_found(kafka_cluster): + sr = kafka_cluster.async_schema_registry() + + subject = _subject_name("test-delete_invalid_subject") + + with pytest.raises(SchemaRegistryError, match="Subject .*not found.*") as e: + await sr.delete_subject(subject) + assert e.value.http_status_code == 404 + assert e.value.error_code == 40401 + + +async def test_api_get_subject_version(kafka_cluster, load_file): + """ + Registers a schema, fetches that schema by it's subject version id. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + + schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + subject = _subject_name('test-get_subject') + + await sr.register_schema(subject, schema) + registered_schema = await sr.lookup_schema(subject, schema) + registered_schema2 = await sr.get_version(subject, registered_schema.version) + + assert registered_schema2.schema_id == registered_schema.schema_id + assert registered_schema2.schema.schema_str == registered_schema.schema.schema_str + assert registered_schema2.version == registered_schema.version + + +async def test_api_get_subject_version_no_version(kafka_cluster, load_file): + sr = kafka_cluster.async_schema_registry() + + # ensures subject exists and has a single version + schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + subject = _subject_name('test-get_subject') + await sr.register_schema(subject, schema) + + with pytest.raises(SchemaRegistryError, match="Version .*not found") as e: + await sr.get_version(subject, version=3) + assert e.value.http_status_code == 404 + assert e.value.error_code == 40402 + + +async def test_api_get_subject_version_invalid(kafka_cluster, load_file): + sr = kafka_cluster.async_schema_registry() + + # ensures subject exists and has a single version + schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + subject = _subject_name('test-get_subject') + await sr.register_schema(subject, schema) + + with pytest.raises(SchemaRegistryError, + match="The specified version .*is not" + " a valid version id.*") as e: + await sr.get_version(subject, version='a') + assert e.value.http_status_code == 422 + assert e.value.error_code == 42202 + + +async def test_api_post_subject_registration(kafka_cluster, load_file): + """ + Registers a schema, fetches that schema by it's subject version id. 
+ + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + + schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + subject = _subject_name('test_registration') + + schema_id = await sr.register_schema(subject, schema) + registered_schema = await sr.lookup_schema(subject, schema) + + assert registered_schema.schema_id == schema_id + assert registered_schema.subject == subject + + +async def test_api_delete_subject_version(kafka_cluster, load_file): + """ + Registers a Schema under a specific subject then deletes it. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + + schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + subject = str(uuid1()) + + await sr.register_schema(subject, schema) + await sr.delete_version(subject, 1) + + assert subject not in await sr.get_subjects() + + +async def test_api_subject_config_update(kafka_cluster, load_file): + """ + Updates a subjects compatibility policy then ensures the same policy + is returned when queried. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + + schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + subject = str(uuid1()) + + await sr.register_schema(subject, schema) + await sr.set_compatibility( + subject_name=subject, + level="FULL_TRANSITIVE" + ) + + assert await sr.get_compatibility(subject_name=subject) == "FULL_TRANSITIVE" + + +async def test_api_config_invalid(kafka_cluster): + """ + Sets an invalid compatibility level, validates the exception. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + """ + sr = kafka_cluster.async_schema_registry() + + with pytest.raises(SchemaRegistryError, match="Invalid compatibility" + " level") as e: + await sr.set_compatibility(level="INVALID") + e.value.http_status_code = 422 + e.value.error_code = 42203 + + +async def test_api_config_update(kafka_cluster): + """ + Updates the global compatibility policy then ensures the same policy + is returned when queried. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + """ + sr = kafka_cluster.async_schema_registry() + + for level in ["BACKWARD", "BACKWARD_TRANSITIVE", "FORWARD", "FORWARD_TRANSITIVE"]: + await sr.set_compatibility(level=level) + assert await sr.get_compatibility() == level + + # revert global compatibility level back to the default. 
+ await sr.set_compatibility(level="BACKWARD") + + +async def test_api_register_logical_schema(kafka_cluster, load_file): + sr = kafka_cluster.async_schema_registry() + + schema = Schema(load_file('logical_date.avsc'), schema_type='AVRO') + subject = _subject_name('test_logical_registration') + + schema_id = await sr.register_schema(subject, schema) + registered_schema = await sr.lookup_schema(subject, schema) + + assert registered_schema.schema_id == schema_id + assert registered_schema.subject == subject diff --git a/tests/integration/schema_registry/_async/test_avro_serializers.py b/tests/integration/schema_registry/_async/test_avro_serializers.py new file mode 100644 index 000000000..0cdf97011 --- /dev/null +++ b/tests/integration/schema_registry/_async/test_avro_serializers.py @@ -0,0 +1,371 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest + +from confluent_kafka import TopicPartition +from confluent_kafka.serialization import (MessageField, + SerializationContext) +from confluent_kafka.schema_registry.avro import (AsyncAvroSerializer, + AsyncAvroDeserializer) +from confluent_kafka.schema_registry import Schema, SchemaReference + + +class User(object): + schema_str = """ + { + "namespace": "confluent.io.examples.serialization.avro", + "name": "User", + "type": "record", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "favorite_number", "type": "int"}, + {"name": "favorite_color", "type": "string"} + ] + } + """ + + def __init__(self, name, favorite_number, favorite_color): + self.name = name + self.favorite_number = favorite_number + self.favorite_color = favorite_color + + def __eq__(self, other): + return all([ + self.name == other.name, + self.favorite_number == other.favorite_number, + self.favorite_color == other.favorite_color]) + + +class AwardProperties(object): + schema_str = """ + { + "namespace": "confluent.io.examples.serialization.avro", + "name": "AwardProperties", + "type": "record", + "fields": [ + {"name": "year", "type": "int"}, + {"name": "points", "type": "int"} + ] + } + """ + + def __init__(self, points, year): + self.points = points + self.year = year + + def __eq__(self, other): + return all([ + self.points == other.points, + self.year == other.year + ]) + + +class Award(object): + schema_str = """ + { + "namespace": "confluent.io.examples.serialization.avro", + "name": "Award", + "type": "record", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "properties", "type": "AwardProperties"} + ] + } + """ + + def __init__(self, name, properties): + self.name = name + self.properties = properties + + def __eq__(self, other): + return all([ + self.name == other.name, + self.properties == other.properties + ]) + + +class AwardedUser(object): + schema_str = """ + { + "namespace": "confluent.io.examples.serialization.avro", + "name": "AwardedUser", + "type": "record", + "fields": [ + {"name": "award", "type": "Award"}, + {"name": "user", 
"type": "User"} + ] + } + """ + + def __init__(self, award, user): + self.award = award + self.user = user + + def __eq__(self, other): + return all([ + self.award == other.award, + self.user == other.user + ]) + + +async def _register_avro_schemas_and_build_awarded_user_schema(kafka_cluster): + sr = kafka_cluster.async_schema_registry() + + user = User('Bowie', 47, 'purple') + award_properties = AwardProperties(10, 2023) + award = Award("Best In Show", award_properties) + awarded_user = AwardedUser(award, user) + + user_schema_ref = SchemaReference("confluent.io.examples.serialization.avro.User", "user", 1) + award_properties_schema_ref = SchemaReference("confluent.io.examples.serialization.avro.AwardProperties", + "award_properties", 1) + award_schema_ref = SchemaReference("confluent.io.examples.serialization.avro.Award", "award", 1) + + await sr.register_schema("user", Schema(User.schema_str, 'AVRO')) + await sr.register_schema("award_properties", Schema(AwardProperties.schema_str, 'AVRO')) + await sr.register_schema("award", Schema(Award.schema_str, 'AVRO', [award_properties_schema_ref])) + + references = [user_schema_ref, award_schema_ref] + schema = Schema(AwardedUser.schema_str, 'AVRO', references) + return awarded_user, schema + + +async def _references_test_common(kafka_cluster, awarded_user, serializer_schema, deserializer_schema): + """ + Common (both reader and writer) avro schema reference test. + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + """ + topic = kafka_cluster.create_topic_and_wait_propogation("reference-avro") + sr = kafka_cluster.async_schema_registry() + + value_serializer = await AsyncAvroSerializer( + sr, + serializer_schema, + lambda user, ctx: dict( + award=dict( + name=user.award.name, + properties=dict(year=user.award.properties.year, points=user.award.properties.points) + ), + user=dict( + name=user.user.name, + favorite_number=user.user.favorite_number, + favorite_color=user.user.favorite_color + ) + ) + ) + + value_deserializer = \ + await AsyncAvroDeserializer( + sr, + deserializer_schema, + lambda user, ctx: AwardedUser( + award=Award( + name=user.get('award').get('name'), + properties=AwardProperties( + year=user.get('award').get('properties').get('year'), + points=user.get('award').get('properties').get('points') + ) + ), + user=User( + name=user.get('user').get('name'), + favorite_number=user.get('user').get('favorite_number'), + favorite_color=user.get('user').get('favorite_color') + ) + ) + ) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + + await producer.produce(topic, value=awarded_user, partition=0) + + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + + consumer.assign([TopicPartition(topic, 0)]) + + msg = await consumer.poll() + awarded_user2 = msg.value() + + assert awarded_user2 == awarded_user + + +@pytest.mark.parametrize("avsc, data, record_type", + [('basic_schema.avsc', {'name': 'abc'}, "record"), + ('primitive_string.avsc', u'Jämtland', "string"), + ('primitive_bool.avsc', True, "bool"), + ('primitive_float.avsc', 32768.2342, "float"), + ('primitive_double.avsc', 68.032768, "float")]) +async def test_avro_record_serialization(kafka_cluster, load_file, avsc, data, record_type): + """ + Tests basic Avro serializer functionality + + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + load_file (callable(str)): Avro file reader + avsc (str) avsc: Avro schema file + data (object): data to be serialized + + """ + topic = 
kafka_cluster.create_topic_and_wait_propogation("serialization-avro") + sr = kafka_cluster.async_schema_registry() + + schema_str = load_file(avsc) + value_serializer = await AsyncAvroSerializer(sr, schema_str) + + value_deserializer = await AsyncAvroDeserializer(sr) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + + await producer.produce(topic, value=data, partition=0) + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + msg = await consumer.poll() + actual = msg.value() + + if record_type == 'record': + assert [v == actual[k] for k, v in data.items()] + elif record_type == 'float': + assert data == pytest.approx(actual) + else: + assert actual == data + + +@pytest.mark.parametrize("avsc, data,record_type", + [('basic_schema.avsc', dict(name='abc'), 'record'), + ('primitive_string.avsc', u'Jämtland', 'string'), + ('primitive_bool.avsc', True, 'bool'), + ('primitive_float.avsc', 768.2340, 'float'), + ('primitive_double.avsc', 6.868, 'float')]) +async def test_delivery_report_serialization(kafka_cluster, load_file, avsc, data, record_type): + """ + Tests basic Avro serializer functionality + + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + load_file (callable(str)): Avro file reader + avsc (str) avsc: Avro schema file + data (object): data to be serialized + + """ + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-avro-dr") + sr = kafka_cluster.async_schema_registry() + schema_str = load_file(avsc) + + value_serializer = await AsyncAvroSerializer(sr, schema_str) + + value_deserializer = await AsyncAvroDeserializer(sr) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + + async def assert_cb(err, msg): + actual = value_deserializer(msg.value(), + SerializationContext(topic, MessageField.VALUE, msg.headers())) + + if record_type == "record": + assert [v == actual[k] for k, v in data.items()] + elif record_type == 'float': + assert data == pytest.approx(actual) + else: + assert actual == data + + await producer.produce(topic, value=data, partition=0, on_delivery=assert_cb) + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + msg = await consumer.poll() + actual = msg.value() + + # schema may include default which need not exist in the original + if record_type == 'record': + assert [v == actual[k] for k, v in data.items()] + elif record_type == 'float': + assert data == pytest.approx(actual) + else: + assert actual == data + + +async def test_avro_record_serialization_custom(kafka_cluster): + """ + Tests basic Avro serializer to_dict and from_dict object hook functionality. 
+ + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + + """ + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-avro") + sr = kafka_cluster.async_schema_registry() + + user = User('Bowie', 47, 'purple') + value_serializer = await AsyncAvroSerializer( + sr, + User.schema_str, + lambda user, ctx: + dict( + name=user.name, + favorite_number=user.favorite_number, + favorite_color=user.favorite_color + ) + ) + + value_deserializer = await AsyncAvroDeserializer( + sr, + User.schema_str, + lambda user_dict, ctx: User(**user_dict) + ) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + + await producer.produce(topic, value=user, partition=0) + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + msg = await consumer.poll() + user2 = msg.value() + + assert user2 == user + + +async def test_avro_reference(kafka_cluster): + """ + Tests Avro schema reference with both serializer and deserializer schemas provided. + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + """ + awarded_user, schema = await _register_avro_schemas_and_build_awarded_user_schema(kafka_cluster) + + await _references_test_common(kafka_cluster, awarded_user, schema, schema) + + +async def test_avro_reference_deserializer_none(kafka_cluster): + """ + Tests Avro schema reference with serializer schema provided and deserializer schema set to None. + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + """ + awarded_user, schema = await _register_avro_schemas_and_build_awarded_user_schema(kafka_cluster) + + await _references_test_common(kafka_cluster, awarded_user, schema, None) diff --git a/tests/integration/schema_registry/_async/test_json_serializers.py b/tests/integration/schema_registry/_async/test_json_serializers.py new file mode 100644 index 000000000..464b41836 --- /dev/null +++ b/tests/integration/schema_registry/_async/test_json_serializers.py @@ -0,0 +1,491 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import pytest +from confluent_kafka import TopicPartition + +from confluent_kafka.error import ConsumeError, ValueSerializationError +from confluent_kafka.schema_registry import SchemaReference, Schema, AsyncSchemaRegistryClient +from confluent_kafka.schema_registry.json_schema import (AsyncJSONSerializer, + AsyncJSONDeserializer) + + +class _TestProduct(object): + def __init__(self, product_id, name, price, tags, dimensions, location): + self.product_id = product_id + self.name = name + self.price = price + self.tags = tags + self.dimensions = dimensions + self.location = location + + def __eq__(self, other): + return all([ + self.product_id == other.product_id, + self.name == other.name, + self.price == other.price, + self.tags == other.tags, + self.dimensions == other.dimensions, + self.location == other.location + ]) + + +class _TestCustomer(object): + def __init__(self, name, id): + self.name = name + self.id = id + + def __eq__(self, other): + return all([ + self.name == other.name, + self.id == other.id + ]) + + +class _TestOrderDetails(object): + def __init__(self, id, customer): + self.id = id + self.customer = customer + + def __eq__(self, other): + return all([ + self.id == other.id, + self.customer == other.customer + ]) + + +class _TestOrder(object): + def __init__(self, order_details, product): + self.order_details = order_details + self.product = product + + def __eq__(self, other): + return all([ + self.order_details == other.order_details, + self.product == other.product + ]) + + +class _TestReferencedProduct(object): + def __init__(self, name, product): + self.name = name + self.product = product + + def __eq__(self, other): + return all([ + self.name == other.name, + self.product == other.product + ]) + + +def _testProduct_to_dict(product_obj, ctx): + """ + Returns testProduct instance in dict format. + + Args: + product_obj (_TestProduct): testProduct instance. + + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + Returns: + dict: product_obj as a dictionary. + + """ + return {"productId": product_obj.product_id, + "productName": product_obj.name, + "price": product_obj.price, + "tags": product_obj.tags, + "dimensions": product_obj.dimensions, + "warehouseLocation": product_obj.location} + + +def _testCustomer_to_dict(customer_obj, ctx): + """ + Returns testCustomer instance in dict format. + + Args: + customer_obj (_TestCustomer): testCustomer instance. + + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + Returns: + dict: customer_obj as a dictionary. + + """ + return {"name": customer_obj.name, + "id": customer_obj.id} + + +def _testOrderDetails_to_dict(orderdetails_obj, ctx): + """ + Returns testOrderDetails instance in dict format. + + Args: + orderdetails_obj (_TestOrderDetails): testOrderDetails instance. + + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + Returns: + dict: orderdetails_obj as a dictionary. + + """ + return {"id": orderdetails_obj.id, + "customer": _testCustomer_to_dict(orderdetails_obj.customer, ctx)} + + +def _testOrder_to_dict(order_obj, ctx): + """ + Returns testOrder instance in dict format. + + Args: + order_obj (_TestOrder): testOrder instance. + + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + Returns: + dict: order_obj as a dictionary. 
+ + """ + return {"order_details": _testOrderDetails_to_dict(order_obj.order_details, ctx), + "product": _testProduct_to_dict(order_obj.product, ctx)} + + +def _testProduct_from_dict(product_dict, ctx): + """ + Returns testProduct instance from its dict format. + + Args: + product_dict (dict): testProduct in dict format. + + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + Returns: + _TestProduct: product_obj instance. + + """ + return _TestProduct(product_dict['productId'], + product_dict['productName'], + product_dict['price'], + product_dict['tags'], + product_dict['dimensions'], + product_dict['warehouseLocation']) + + +def _testCustomer_from_dict(customer_dict, ctx): + """ + Returns testCustomer instance from its dict format. + + Args: + customer_dict (dict): testCustomer in dict format. + + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + Returns: + _TestCustomer: customer_obj instance. + + """ + return _TestCustomer(customer_dict['name'], + customer_dict['id']) + + +def _testOrderDetails_from_dict(orderdetails_dict, ctx): + """ + Returns testOrderDetails instance from its dict format. + + Args: + orderdetails_dict (dict): testOrderDetails in dict format. + + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + Returns: + _TestOrderDetails: orderdetails_obj instance. + + """ + return _TestOrderDetails(orderdetails_dict['id'], + _testCustomer_from_dict(orderdetails_dict['customer'], ctx)) + + +def _testOrder_from_dict(order_dict, ctx): + """ + Returns testOrder instance from its dict format. + + Args: + order_dict (dict): testOrder in dict format. + + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + Returns: + _TestOrder: order_obj instance. + + """ + return _TestOrder(_testOrderDetails_from_dict(order_dict['order_details'], ctx), + _testProduct_from_dict(order_dict['product'], ctx)) + + +async def test_json_record_serialization(kafka_cluster, load_file): + """ + Tests basic JsonSerializer and JsonDeserializer basic functionality. + + product.json from: + https://json-schema.org/learn/getting-started-step-by-step.html + + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + + load_file (callable(str)): JSON Schema file reader + + """ + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-json") + sr = kafka_cluster.async_schema_registry() + + schema_str = load_file("product.json") + value_serializer = await AsyncJSONSerializer(schema_str, sr) + value_deserializer = await AsyncJSONDeserializer(schema_str) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + + record = {"productId": 1, + "productName": "An ice sculpture", + "price": 12.50, + "tags": ["cold", "ice"], + "dimensions": { + "length": 7.0, + "width": 12.0, + "height": 9.5 + }, + "warehouseLocation": { + "latitude": -78.75, + "longitude": 20.4 + }} + + await producer.produce(topic, value=record, partition=0) + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + msg = await consumer.poll() + actual = msg.value() + + assert all([actual[k] == v for k, v in record.items()]) + + +async def test_json_record_serialization_incompatible(kafka_cluster, load_file): + """ + Tests Serializer validation functionality. 
+ + product.json from: + https://json-schema.org/learn/getting-started-step-by-step.html + + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + + load_file (callable(str)): JSON Schema file reader + + """ + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-json") + sr = kafka_cluster.async_schema_registry() + + schema_str = load_file("product.json") + value_serializer = await AsyncJSONSerializer(schema_str, sr) + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + + record = {"contractorId": 1, + "contractorName": "David Davidson", + "contractRate": 1250, + "trades": ["mason"]} + + with pytest.raises(ValueSerializationError, + match=r"(.*) is a required property"): + await producer.produce(topic, value=record, partition=0) + + +async def test_json_record_serialization_custom(kafka_cluster, load_file): + """ + Ensures to_dict and from_dict hooks are properly applied by the serializer. + + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + + load_file (callable(str)): JSON Schema file reader + + """ + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-json") + sr = kafka_cluster.async_schema_registry() + + schema_str = load_file("product.json") + value_serializer = await AsyncJSONSerializer(schema_str, sr, to_dict=_testProduct_to_dict) + value_deserializer = await AsyncJSONDeserializer( + schema_str, + from_dict=_testProduct_from_dict + ) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + + record = _TestProduct(product_id=1, + name="The ice sculpture", + price=12.50, + tags=["cold", "ice"], + dimensions={"length": 7.0, + "width": 12.0, + "height": 9.5}, + location={"latitude": -78.75, + "longitude": 20.4}) + + await producer.produce(topic, value=record, partition=0) + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + msg = await consumer.poll() + actual = msg.value() + + assert all([getattr(actual, attribute) == getattr(record, attribute) + for attribute in vars(record)]) + + +async def test_json_record_deserialization_mismatch(kafka_cluster, load_file): + """ + Ensures to_dict and from_dict hooks are properly applied by the serializer. 
+ + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + + load_file (callable(str)): JSON Schema file reader + + """ + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-json") + sr = kafka_cluster.async_schema_registry() + + schema_str = load_file("contractor.json") + schema_str2 = load_file("product.json") + + value_serializer = await AsyncJSONSerializer(schema_str, sr) + value_deserializer = await AsyncJSONDeserializer(schema_str2) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + + record = {"contractorId": 2, + "contractorName": "Magnus Edenhill", + "contractRate": 30, + "trades": ["pickling"]} + + await producer.produce(topic, value=record, partition=0) + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + with pytest.raises( + ConsumeError, + match="'productId' is a required property"): + await consumer.poll() + + +async def _register_referenced_schemas(sr: AsyncSchemaRegistryClient, load_file): + await sr.register_schema("product", Schema(load_file("product.json"), 'JSON')) + await sr.register_schema("customer", Schema(load_file("customer.json"), 'JSON')) + await sr.register_schema("order_details", Schema(load_file("order_details.json"), 'JSON', [ + SchemaReference("http://example.com/customer.schema.json", "customer", 1)])) + + order_schema = Schema(load_file("order.json"), 'JSON', + [SchemaReference("http://example.com/order_details.schema.json", "order_details", 1), + SchemaReference("http://example.com/product.schema.json", "product", 1)]) + return order_schema + + +async def test_json_reference(kafka_cluster, load_file): + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-json") + sr = kafka_cluster.async_schema_registry() + + product = {"productId": 1, + "productName": "An ice sculpture", + "price": 12.50, + "tags": ["cold", "ice"], + "dimensions": { + "length": 7.0, + "width": 12.0, + "height": 9.5 + }, + "warehouseLocation": { + "latitude": -78.75, + "longitude": 20.4 + }} + customer = {"name": "John Doe", "id": 1} + order_details = {"id": 1, "customer": customer} + order = {"order_details": order_details, "product": product} + + schema = await _register_referenced_schemas(sr, load_file) + + value_serializer = await AsyncJSONSerializer(schema, sr) + value_deserializer = await AsyncJSONDeserializer(schema, schema_registry_client=sr) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + await producer.produce(topic, value=order, partition=0) + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + msg = await consumer.poll() + actual = msg.value() + + assert all([actual[k] == v for k, v in order.items()]) + + +async def test_json_reference_custom(kafka_cluster, load_file): + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-json") + sr = kafka_cluster.async_schema_registry() + + product = _TestProduct(product_id=1, + name="The ice sculpture", + price=12.50, + tags=["cold", "ice"], + dimensions={"length": 7.0, + "width": 12.0, + "height": 9.5}, + location={"latitude": -78.75, + "longitude": 20.4}) + customer = _TestCustomer(name="John Doe", id=1) + order_details = _TestOrderDetails(id=1, customer=customer) + order = _TestOrder(order_details=order_details, product=product) + + schema = await _register_referenced_schemas(sr, load_file) + + value_serializer = 
await AsyncJSONSerializer(schema, sr, to_dict=_testOrder_to_dict) + value_deserializer = await AsyncJSONDeserializer(schema, schema_registry_client=sr, from_dict=_testOrder_from_dict) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + await producer.produce(topic, value=order, partition=0) + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + msg = await consumer.poll() + actual = msg.value() + + assert actual == order diff --git a/tests/integration/schema_registry/_async/test_proto_serializers.py b/tests/integration/schema_registry/_async/test_proto_serializers.py new file mode 100644 index 000000000..0e65686e2 --- /dev/null +++ b/tests/integration/schema_registry/_async/test_proto_serializers.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python +# +# Copyright 2016 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest + +from confluent_kafka import TopicPartition, KafkaException, KafkaError +from confluent_kafka.error import ConsumeError +from confluent_kafka.schema_registry.protobuf import AsyncProtobufSerializer, AsyncProtobufDeserializer +from tests.integration.schema_registry.data.proto import metadata_proto_pb2, NestedTestProto_pb2, TestProto_pb2, \ + PublicTestProto_pb2 +from tests.integration.schema_registry.data.proto.DependencyTestProto_pb2 import DependencyMessage +from tests.integration.schema_registry.data.proto.exampleProtoCriteo_pb2 import ClickCas + + +@pytest.mark.parametrize("pb2, data", [ + (TestProto_pb2.TestMessage, {'test_string': "abc", + 'test_bool': True, + 'test_bytes': b'look at these bytes', + 'test_double': 1.0, + 'test_float': 12.0}), + (PublicTestProto_pb2.TestMessage, {'test_string': "abc", + 'test_bool': True, + 'test_bytes': b'look at these bytes', + 'test_double': 1.0, + 'test_float': 12.0}), + (NestedTestProto_pb2.NestedMessage, {'user_id': + NestedTestProto_pb2.UserId( + kafka_user_id='oneof_str'), + 'is_active': True, + 'experiments_active': ['x', 'y', '1'], + 'status': NestedTestProto_pb2.INACTIVE, + 'complex_type': + NestedTestProto_pb2.ComplexType( + one_id='oneof_str', + is_active=False)}) +]) +async def test_protobuf_message_serialization(kafka_cluster, pb2, data): + """ + Validates that we get the same message back that we put in. 
+
+    """
+    topic = kafka_cluster.create_topic_and_wait_propogation("serialization-proto")
+    sr = kafka_cluster.async_schema_registry()
+
+    value_serializer = await AsyncProtobufSerializer(pb2, sr, {'use.deprecated.format': False})
+    value_deserializer = await AsyncProtobufDeserializer(pb2, {'use.deprecated.format': False})
+
+    producer = kafka_cluster.async_producer(value_serializer=value_serializer)
+    consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer)
+    consumer.assign([TopicPartition(topic, 0)])
+
+    expect = pb2(**data)
+    await producer.produce(topic, value=expect, partition=0)
+    producer.flush()
+
+    msg = await consumer.poll()
+    actual = msg.value()
+
+    assert [getattr(expect, k) == getattr(actual, k) for k in data.keys()]
+
+
+@pytest.mark.parametrize("pb2, expected_refs", [
+    (TestProto_pb2.TestMessage, ['google/protobuf/descriptor.proto']),
+    (NestedTestProto_pb2.NestedMessage, ['google/protobuf/timestamp.proto']),
+    (DependencyMessage, ['NestedTestProto.proto', 'PublicTestProto.proto']),
+    (ClickCas, ['metadata_proto.proto', 'common_proto.proto'])
+])
+async def test_protobuf_reference_registration(kafka_cluster, pb2, expected_refs):
+    """
+    Registers multiple messages with dependencies then queries the Schema
+    Registry to ensure the references match up.
+
+    """
+    sr = kafka_cluster.async_schema_registry()
+    topic = kafka_cluster.create_topic_and_wait_propogation("serialization-proto-refs")
+    serializer = await AsyncProtobufSerializer(pb2, sr, {'use.deprecated.format': False})
+    producer = kafka_cluster.async_producer(key_serializer=serializer)
+
+    await producer.produce(topic, key=pb2(), partition=0)
+    producer.flush()
+
+    registered_refs = (await sr.get_schema(serializer._schema_id.id)).references
+
+    assert expected_refs.sort() == [ref.name for ref in registered_refs].sort()
+
+
+async def test_protobuf_serializer_type_mismatch(kafka_cluster):
+    """
+    Ensures an Exception is raised when serializing an unexpected type.
+
+    """
+    pb2_1 = TestProto_pb2.TestMessage
+    pb2_2 = NestedTestProto_pb2.NestedMessage
+
+    sr = kafka_cluster.async_schema_registry()
+    topic = kafka_cluster.create_topic_and_wait_propogation("serialization-proto-refs")
+    serializer = await AsyncProtobufSerializer(pb2_1, sr, {'use.deprecated.format': False})
+
+    producer = kafka_cluster.async_producer(key_serializer=serializer)
+
+    with pytest.raises(KafkaException,
+                       match=r"message must be of type <class"
+                             r" 'TestProto_pb2.TestMessage'\> not \<class"
+                             r" 'NestedTestProto_pb2.NestedMessage'\>"):
+        await producer.produce(topic, key=pb2_2())
+
+
+async def test_protobuf_deserializer_type_mismatch(kafka_cluster):
+    """
+    Ensures an Exception is raised when deserializing an unexpected type.
+ + """ + pb2_1 = PublicTestProto_pb2.TestMessage + pb2_2 = metadata_proto_pb2.HDFSOptions + + sr = kafka_cluster.async_schema_registry() + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-proto-refs") + serializer = await AsyncProtobufSerializer(pb2_1, sr, {'use.deprecated.format': False}) + deserializer = await AsyncProtobufDeserializer(pb2_2, {'use.deprecated.format': False}) + + producer = kafka_cluster.async_producer(key_serializer=serializer) + consumer = kafka_cluster.async_consumer(key_deserializer=deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + def dr(err, msg): + print("dr msg {} {}".format(msg.key(), msg.value())) + + await producer.produce( + topic, + key=pb2_1(test_string='abc', test_bool=True, test_bytes=b'def'), + partition=0 + ) + producer.flush() + + with pytest.raises(ConsumeError) as e: + await consumer.poll() + assert e.value.code == KafkaError._KEY_DESERIALIZATION diff --git a/tests/integration/schema_registry/_sync/README.md b/tests/integration/schema_registry/_sync/README.md new file mode 100644 index 000000000..905a46481 --- /dev/null +++ b/tests/integration/schema_registry/_sync/README.md @@ -0,0 +1,7 @@ +# Auto-generated Directory + +This directory contains auto-generated code. Do not edit these files directly. + +To make changes: +1. Edit the corresponding files in the sibling `_async` directory +2. Run `python tools/unasync.py` to propagate the changes to this `_sync` directory diff --git a/tests/integration/schema_registry/_sync/test_api_client.py b/tests/integration/schema_registry/_sync/test_api_client.py index 6b75e3b52..d841eb174 100644 --- a/tests/integration/schema_registry/_sync/test_api_client.py +++ b/tests/integration/schema_registry/_sync/test_api_client.py @@ -438,8 +438,10 @@ def test_api_subject_config_update(kafka_cluster, load_file): subject = str(uuid1()) sr.register_schema(subject, schema) - sr.set_compatibility(subject_name=subject, - level="FULL_TRANSITIVE") + sr.set_compatibility( + subject_name=subject, + level="FULL_TRANSITIVE" + ) assert sr.get_compatibility(subject_name=subject) == "FULL_TRANSITIVE" diff --git a/tests/integration/schema_registry/_sync/test_avro_serializers.py b/tests/integration/schema_registry/_sync/test_avro_serializers.py index 322637ae7..b40426080 100644 --- a/tests/integration/schema_registry/_sync/test_avro_serializers.py +++ b/tests/integration/schema_registry/_sync/test_avro_serializers.py @@ -154,27 +154,41 @@ def _references_test_common(kafka_cluster, awarded_user, serializer_schema, dese topic = kafka_cluster.create_topic_and_wait_propogation("reference-avro") sr = kafka_cluster.schema_registry() - value_serializer = AvroSerializer(sr, serializer_schema, - lambda user, ctx: - dict(award=dict(name=user.award.name, - properties=dict(year=user.award.properties.year, - points=user.award.properties.points)), - user=dict(name=user.user.name, - favorite_number=user.user.favorite_number, - favorite_color=user.user.favorite_color))) + value_serializer = AvroSerializer( + sr, + serializer_schema, + lambda user, ctx: dict( + award=dict( + name=user.award.name, + properties=dict(year=user.award.properties.year, points=user.award.properties.points) + ), + user=dict( + name=user.user.name, + favorite_number=user.user.favorite_number, + favorite_color=user.user.favorite_color + ) + ) + ) value_deserializer = \ - AvroDeserializer(sr, deserializer_schema, - lambda user, ctx: - AwardedUser(award=Award(name=user.get('award').get('name'), - properties=AwardProperties( - 
year=user.get('award').get('properties').get( - 'year'), - points=user.get('award').get('properties').get( - 'points'))), - user=User(name=user.get('user').get('name'), - favorite_number=user.get('user').get('favorite_number'), - favorite_color=user.get('user').get('favorite_color')))) + AvroDeserializer( + sr, + deserializer_schema, + lambda user, ctx: AwardedUser( + award=Award( + name=user.get('award').get('name'), + properties=AwardProperties( + year=user.get('award').get('properties').get('year'), + points=user.get('award').get('properties').get('points') + ) + ), + user=User( + name=user.get('user').get('name'), + favorite_number=user.get('user').get('favorite_number'), + favorite_color=user.get('user').get('favorite_color') + ) + ) + ) producer = kafka_cluster.producer(value_serializer=value_serializer) @@ -304,15 +318,22 @@ def test_avro_record_serialization_custom(kafka_cluster): sr = kafka_cluster.schema_registry() user = User('Bowie', 47, 'purple') - value_serializer = AvroSerializer(sr, User.schema_str, - lambda user, ctx: - dict(name=user.name, - favorite_number=user.favorite_number, - favorite_color=user.favorite_color)) - - value_deserializer = AvroDeserializer(sr, User.schema_str, - lambda user_dict, ctx: - User(**user_dict)) + value_serializer = AvroSerializer( + sr, + User.schema_str, + lambda user, ctx: + dict( + name=user.name, + favorite_number=user.favorite_number, + favorite_color=user.favorite_color + ) + ) + + value_deserializer = AvroDeserializer( + sr, + User.schema_str, + lambda user_dict, ctx: User(**user_dict) + ) producer = kafka_cluster.producer(value_serializer=value_serializer) diff --git a/tests/integration/schema_registry/_sync/test_json_serializers.py b/tests/integration/schema_registry/_sync/test_json_serializers.py index ae67c30f2..0c4a2545d 100644 --- a/tests/integration/schema_registry/_sync/test_json_serializers.py +++ b/tests/integration/schema_registry/_sync/test_json_serializers.py @@ -336,10 +336,11 @@ def test_json_record_serialization_custom(kafka_cluster, load_file): sr = kafka_cluster.schema_registry() schema_str = load_file("product.json") - value_serializer = JSONSerializer(schema_str, sr, - to_dict=_testProduct_to_dict) - value_deserializer = JSONDeserializer(schema_str, - from_dict=_testProduct_from_dict) + value_serializer = JSONSerializer(schema_str, sr, to_dict=_testProduct_to_dict) + value_deserializer = JSONDeserializer( + schema_str, + from_dict=_testProduct_from_dict + ) producer = kafka_cluster.producer(value_serializer=value_serializer) diff --git a/tests/integration/schema_registry/_sync/test_proto_serializers.py b/tests/integration/schema_registry/_sync/test_proto_serializers.py index 7ea741856..9b3ca3197 100644 --- a/tests/integration/schema_registry/_sync/test_proto_serializers.py +++ b/tests/integration/schema_registry/_sync/test_proto_serializers.py @@ -92,7 +92,7 @@ def test_protobuf_reference_registration(kafka_cluster, pb2, expected_refs): producer.produce(topic, key=pb2(), partition=0) producer.flush() - registered_refs = sr.get_schema(serializer._schema_id.id).references + registered_refs = (sr.get_schema(serializer._schema_id.id)).references assert expected_refs.sort() == [ref.name for ref in registered_refs].sort() @@ -138,10 +138,11 @@ def test_protobuf_deserializer_type_mismatch(kafka_cluster): def dr(err, msg): print("dr msg {} {}".format(msg.key(), msg.value())) - producer.produce(topic, key=pb2_1(test_string='abc', - test_bool=True, - test_bytes=b'def'), - partition=0) + producer.produce( + topic, + 
key=pb2_1(test_string='abc', test_bool=True, test_bytes=b'def'), + partition=0 + ) producer.flush() with pytest.raises(ConsumeError) as e: diff --git a/tests/test_unasync.py b/tests/test_unasync.py new file mode 100644 index 000000000..ea475e04f --- /dev/null +++ b/tests/test_unasync.py @@ -0,0 +1,190 @@ +from tools.unasync import unasync, unasync_line, unasync_file_check + +import os +import tempfile + +import pytest + + +@pytest.fixture +def temp_dirs(): + """Create temporary directories for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create async and sync directories + async_dir = os.path.join(temp_dir, "async") + sync_dir = os.path.join(temp_dir, "sync") + os.makedirs(async_dir) + os.makedirs(sync_dir) + yield async_dir, sync_dir + + +def test_unasync_line(): + """Test the unasync_line function with various inputs.""" + test_cases = [ + ("async def test():", "def test():"), + ("await some_func()", "some_func()"), + ("from confluent_kafka.schema_registry.common import asyncinit", ""), + ("@asyncinit", ""), + ("import asyncio", ""), + ("asyncio.sleep(1)", "time.sleep(1)"), + ("class AsyncTest:", "class Test:"), + ("class _AsyncTest:", "class _Test:"), + ("async_test_func", "test_func"), + ] + + for input_line, expected in test_cases: + assert unasync_line(input_line) == expected + + +def test_unasync_file_check(temp_dirs): + """Test the unasync_file_check function with various scenarios.""" + async_dir, sync_dir = temp_dirs + + # Test case 1: Files match + async_file = os.path.join(async_dir, "test1.py") + sync_file = os.path.join(sync_dir, "test1.py") + os.makedirs(os.path.dirname(sync_file), exist_ok=True) + + with open(async_file, "w") as f: + f.write("""async def test(): + await asyncio.sleep(1) +""") + + with open(sync_file, "w") as f: + f.write("""def test(): + time.sleep(1) +""") + + # This should return True + assert unasync_file_check(async_file, sync_file) is True + + # Test case 2: Files don't match + async_file = os.path.join(async_dir, "test2.py") + sync_file = os.path.join(sync_dir, "test2.py") + os.makedirs(os.path.dirname(sync_file), exist_ok=True) + + with open(async_file, "w") as f: + f.write("""async def test(): + await asyncio.sleep(1) +""") + + with open(sync_file, "w") as f: + f.write("""def test(): + # This is wrong + asyncio.sleep(1) +""") + + # This should return False + assert unasync_file_check(async_file, sync_file) is False + + # Test case 3: Files have different lengths + async_file = os.path.join(async_dir, "test3.py") + sync_file = os.path.join(sync_dir, "test3.py") + os.makedirs(os.path.dirname(sync_file), exist_ok=True) + + with open(async_file, "w") as f: + f.write("""async def test(): + await asyncio.sleep(1) + return "test" +""") + + with open(sync_file, "w") as f: + f.write("""def test(): + time.sleep(1) +""") + + # This should return False + assert unasync_file_check(async_file, sync_file) is False + + # Test case 4: File not found + with pytest.raises(ValueError, match="Error comparing"): + unasync_file_check("nonexistent.py", "also_nonexistent.py") + + +def test_unasync_generation(temp_dirs): + """Test the unasync generation functionality.""" + async_dir, sync_dir = temp_dirs + + # Create a test async file + test_file = os.path.join(async_dir, "test.py") + with open(test_file, "w") as f: + f.write("""async def test_func(): + await asyncio.sleep(1) + return "test" + +class AsyncTest: + async def test_method(self): + await self.some_async() +""") + + # Run unasync with test directories + dir_pairs = [(async_dir, sync_dir)] + 
unasync(dir_pairs=dir_pairs, check=False) + + # Check if sync file was created + sync_file = os.path.join(sync_dir, "test.py") + assert os.path.exists(sync_file) + + # Check content + with open(sync_file, "r") as f: + content = f.read() + assert "async def" not in content + assert "await" not in content + assert "AsyncTest" not in content + assert "Test" in content + + # Check if README was created + readme_file = os.path.join(sync_dir, "README.md") + assert os.path.exists(readme_file) + with open(readme_file, "r") as f: + readme_content = f.read() + assert "Auto-generated Directory" in readme_content + assert "Do not edit these files directly" in readme_content + + +def test_unasync_check(temp_dirs): + """Test the unasync check functionality.""" + async_dir, sync_dir = temp_dirs + + # Create a test async file + test_file = os.path.join(async_dir, "test.py") + with open(test_file, "w") as f: + f.write("""async def test_func(): + await asyncio.sleep(1) + return "test" +""") + + # Create an incorrect sync file + sync_file = os.path.join(sync_dir, "test.py") + os.makedirs(os.path.dirname(sync_file), exist_ok=True) + with open(sync_file, "w") as f: + f.write("""def test_func(): + time.sleep(1) + return "test" + # Extra line that shouldn't be here +""") + + # Run unasync check with test directories + dir_pairs = [(async_dir, sync_dir)] + with pytest.raises(SystemExit) as excinfo: + unasync(dir_pairs=dir_pairs, check=True) + assert excinfo.value.code == 1 + + +def test_unasync_missing_sync_file(temp_dirs): + """Test unasync check with missing sync files.""" + async_dir, sync_dir = temp_dirs + + # Create a test async file + test_file = os.path.join(async_dir, "test.py") + with open(test_file, "w") as f: + f.write("""async def test_func(): + await asyncio.sleep(1) + return "test" +""") + + # Run unasync check with test directories + dir_pairs = [(async_dir, sync_dir)] + with pytest.raises(SystemExit) as excinfo: + unasync(dir_pairs=dir_pairs, check=True) + assert excinfo.value.code == 1 diff --git a/tools/source-package-verification.sh b/tools/source-package-verification.sh index a84e20c5a..9d3662337 100755 --- a/tools/source-package-verification.sh +++ b/tools/source-package-verification.sh @@ -27,13 +27,16 @@ if [[ $RUN_COVERAGE == true ]]; then exit 0 fi +echo "Checking for uncommitted changes in generated _sync directories" +python3 tools/unasync.py --check + python3 -m pip install . if [[ $OS_NAME == linux && $ARCH == x64 ]]; then if [[ -z $TEST_CONSUMER_GROUP_PROTOCOL ]]; then # Run these actions and tests only in this case echo "Building documentation ..." - flake8 --exclude ./_venv,*_pb2.py + flake8 --exclude ./_venv,*_pb2.py,./build pip install -r requirements/requirements-docs.txt make docs diff --git a/tools/unasync.py b/tools/unasync.py new file mode 100644 index 000000000..173d1d48e --- /dev/null +++ b/tools/unasync.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python + +import os +import re +import sys +import argparse +import difflib + +# List of directories to convert from async to sync +# Each tuple contains the async directory and its sync counterpart +# If you add a new _async directory and want the _sync directory to be +# generated, you must add it to this list. 
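+#
+# For example, a hypothetical new module pair would be registered by appending
+# another (async_dir, sync_dir) tuple:
+#     ("src/confluent_kafka/some_module/_async", "src/confluent_kafka/some_module/_sync"),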
+ASYNC_TO_SYNC = [ + ("src/confluent_kafka/schema_registry/_async", "src/confluent_kafka/schema_registry/_sync"), + ("tests/integration/schema_registry/_async", "tests/integration/schema_registry/_sync") +] + +SUBS = [ + ('from confluent_kafka.schema_registry.common import asyncinit', ''), + ('@asyncinit', ''), + ('import asyncio', ''), + ('asyncio.sleep', 'time.sleep'), + + ('Async([A-Z][A-Za-z0-9_]*)', r'\2'), + ('_Async([A-Z][A-Za-z0-9_]*)', r'_\2'), + ('async_([a-z][A-Za-z0-9_]*)', r'\2'), + + ('async def', 'def'), + ('await ', ''), + ('aclose', 'close'), + ('__aenter__', '__enter__'), + ('__aexit__', '__exit__'), + ('__aiter__', '__iter__'), +] + +COMPILED_SUBS = [ + (re.compile(r'(^|\b)' + regex + r'($|\b)'), repl) + for regex, repl in SUBS +] + +USED_SUBS = set() + + +def unasync_line(line): + for index, (regex, repl) in enumerate(COMPILED_SUBS): + old_line = line + line = re.sub(regex, repl, line) + if old_line != line: + USED_SUBS.add(index) + return line + + +def unasync_file(in_path, out_path): + with open(in_path, "r") as in_file: + with open(out_path, "w", newline="") as out_file: + for line in in_file.readlines(): + line = unasync_line(line) + out_file.write(line) + + +def unasync_file_check(in_path, out_path): + """Check if the sync file matches the expected generated content. + + Args: + in_path: Path to the async file + out_path: Path to the sync file + + Returns: + bool: True if files match, False if they don't + + Raises: + ValueError: If there's an error reading the files + """ + try: + with open(in_path, "r") as in_file: + async_content = in_file.read() + expected_content = "".join(unasync_line(line) for line in async_content.splitlines(keepends=True)) + + with open(out_path, "r") as out_file: + actual_content = out_file.read() + + if actual_content != expected_content: + diff = difflib.unified_diff( + expected_content.splitlines(keepends=True), + actual_content.splitlines(keepends=True), + fromfile=in_path, + tofile=out_path, + n=3 # Show 3 lines of context + ) + print(''.join(diff)) + return False + return True + except Exception as e: + print(f"Error comparing {in_path} and {out_path}: {e}") + raise ValueError(f"Error comparing {in_path} and {out_path}: {e}") + + +def check_sync_files(dir_pairs): + """Check if all sync files match their expected generated content. + Returns a list of files that have differences.""" + files_with_diff = [] + + for async_dir, sync_dir in dir_pairs: + for dirpath, _, filenames in os.walk(async_dir): + for filename in filenames: + if not filename.endswith('.py'): + continue + rel_dir = os.path.relpath(dirpath, async_dir) + async_path = os.path.normpath(os.path.join(async_dir, rel_dir, filename)) + sync_path = os.path.normpath(os.path.join(sync_dir, rel_dir, filename)) + + if not os.path.exists(sync_path): + files_with_diff.append(sync_path) + continue + + if not unasync_file_check(async_path, sync_path): + files_with_diff.append(sync_path) + + return files_with_diff + + +def unasync_dir(in_dir, out_dir): + # Create the output directory if it doesn't exist + os.makedirs(out_dir, exist_ok=True) + + # Create README.md in the sync directory + readme_path = os.path.join(out_dir, "README.md") + readme_content = """# Auto-generated Directory + +This directory contains auto-generated code. Do not edit these files directly. + +To make changes: +1. Edit the corresponding files in the sibling `_async` directory +2. 
Run `python tools/unasync.py` to propagate the changes to this `_sync` directory +""" + with open(readme_path, "w") as f: + f.write(readme_content) + + for dirpath, _, filenames in os.walk(in_dir): + for filename in filenames: + if not filename.endswith('.py'): + continue + rel_dir = os.path.relpath(dirpath, in_dir) + in_path = os.path.normpath(os.path.join(in_dir, rel_dir, filename)) + out_path = os.path.normpath(os.path.join(out_dir, rel_dir, filename)) + # Create the subdirectory if it doesn't exist + os.makedirs(os.path.dirname(out_path), exist_ok=True) + print(in_path, '->', out_path) + unasync_file(in_path, out_path) + + +def unasync(dir_pairs=None, check=False): + """Convert async code to sync code. + + Args: + dir_pairs: List of (async_dir, sync_dir) tuples to process. If None, uses ASYNC_TO_SYNC. + check: If True, only check if sync files are up to date without modifying them. + """ + if dir_pairs is None: + dir_pairs = ASYNC_TO_SYNC + + files_with_diff = [] + if check: + files_with_diff = check_sync_files(dir_pairs) + + if files_with_diff: + print("\n⚠️ Detected differences between async and sync files.") + print("\nFiles that need to be regenerated:") + for file in files_with_diff: + print(f" - {file}") + print("\nPlease run this script again (without the --check flag) to regenerate the sync files.") + sys.exit(1) + else: + print("\n✅ All _sync directories are up to date!") + if not check: + print("Converting async code to sync code...") + for async_dir, sync_dir in dir_pairs: + unasync_dir(async_dir, sync_dir) + + print("\n✅ Generated sync code from async code.") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Convert async code to sync code') + parser.add_argument( + '--check', + action='store_true', + help='Exit with non-zero status if sync directory has any differences') + args = parser.parse_args() + unasync(check=args.check) diff --git a/tox.ini b/tox.ini index 55c786e54..c646479aa 100644 --- a/tox.ini +++ b/tox.ini @@ -29,3 +29,7 @@ norecursedirs = tests/integration/*/java exclude = venv*,.venv*,env,.env,.tox,.toxenv,.git,build,docs,tools,tmp-build,*_pb2.py,*tmp-KafkaCluster/* max-line-length = 119 accept-encodings = utf-8 +per-file-ignores = + ./src/confluent_kafka/schema_registry/_sync/avro.py: E303 + ./src/confluent_kafka/schema_registry/_sync/json_schema.py: E303 + ./src/confluent_kafka/schema_registry/_sync/protobuf.py: E303