From 4799d921cf2278b8e02cb6d5c4df377d3393b0b6 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Thu, 17 Apr 2025 14:49:03 -0700 Subject: [PATCH 01/32] Refactor for async support --- .../schema_registry/__init__.py | 64 +- .../schema_registry/_sync/__init__.py | 17 + .../schema_registry/_sync/avro.py | 577 +++++ .../schema_registry/_sync/json_schema.py | 643 ++++++ .../schema_registry/_sync/protobuf.py | 801 +++++++ .../_sync/schema_registry_client.py | 1115 ++++++++++ .../schema_registry/_sync/serde.py | 252 +++ src/confluent_kafka/schema_registry/avro.py | 796 +------ .../schema_registry/common/__init__.py | 91 + .../schema_registry/common/avro.py | 233 ++ .../schema_registry/common/json_schema.py | 167 ++ .../schema_registry/common/protobuf.py | 358 +++ .../common/schema_registry_client.py | 897 ++++++++ .../schema_registry/common/serde.py | 298 +++ .../schema_registry/json_schema.py | 791 +------ .../schema_registry/protobuf.py | 1137 +--------- .../schema_registry/schema_registry_client.py | 1966 +---------------- src/confluent_kafka/schema_registry/serde.py | 523 +---- .../schema_registry/{ => _sync}/__init__.py | 0 .../{ => _sync}/test_api_client.py | 0 .../{ => _sync}/test_avro_serializers.py | 2 + .../{ => _sync}/test_json_serializers.py | 4 +- .../{ => _sync}/test_proto_serializers.py | 4 +- 23 files changed, 5474 insertions(+), 5262 deletions(-) create mode 100644 src/confluent_kafka/schema_registry/_sync/__init__.py create mode 100644 src/confluent_kafka/schema_registry/_sync/avro.py create mode 100644 src/confluent_kafka/schema_registry/_sync/json_schema.py create mode 100644 src/confluent_kafka/schema_registry/_sync/protobuf.py create mode 100644 src/confluent_kafka/schema_registry/_sync/schema_registry_client.py create mode 100644 src/confluent_kafka/schema_registry/_sync/serde.py create mode 100644 src/confluent_kafka/schema_registry/common/__init__.py create mode 100644 src/confluent_kafka/schema_registry/common/avro.py create mode 100644 src/confluent_kafka/schema_registry/common/json_schema.py create mode 100644 src/confluent_kafka/schema_registry/common/protobuf.py create mode 100644 src/confluent_kafka/schema_registry/common/schema_registry_client.py create mode 100644 src/confluent_kafka/schema_registry/common/serde.py rename tests/integration/schema_registry/{ => _sync}/__init__.py (100%) rename tests/integration/schema_registry/{ => _sync}/test_api_client.py (100%) rename tests/integration/schema_registry/{ => _sync}/test_avro_serializers.py (99%) rename tests/integration/schema_registry/{ => _sync}/test_json_serializers.py (99%) rename tests/integration/schema_registry/{ => _sync}/test_proto_serializers.py (96%) diff --git a/src/confluent_kafka/schema_registry/__init__.py b/src/confluent_kafka/schema_registry/__init__.py index e4ad4be17..d405f7433 100644 --- a/src/confluent_kafka/schema_registry/__init__.py +++ b/src/confluent_kafka/schema_registry/__init__.py @@ -35,7 +35,13 @@ ServerConfig ) -_MAGIC_BYTE = 0 +from .common import ( + _MAGIC_BYTE, + topic_subject_name_strategy, + topic_record_subject_name_strategy, + record_subject_name_strategy, + reference_subject_name_strategy +) __all__ = [ "ConfigCompatibilityLevel", @@ -57,59 +63,3 @@ "topic_record_subject_name_strategy", "record_subject_name_strategy" ] - - -def topic_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: - """ - Constructs a subject name in the form of {topic}-key|value. - - Args: - ctx (SerializationContext): Metadata pertaining to the serialization - operation. 
- - record_name (Optional[str]): Record name. - - """ - return ctx.topic + "-" + ctx.field - - -def topic_record_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: - """ - Constructs a subject name in the form of {topic}-{record_name}. - - Args: - ctx (SerializationContext): Metadata pertaining to the serialization - operation. - - record_name (Optional[str]): Record name. - - """ - return ctx.topic + "-" + record_name if record_name is not None else None - - -def record_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: - """ - Constructs a subject name in the form of {record_name}. - - Args: - ctx (SerializationContext): Metadata pertaining to the serialization - operation. - - record_name (Optional[str]): Record name. - - """ - return record_name if record_name is not None else None - - -def reference_subject_name_strategy(ctx, schema_ref: SchemaReference) -> Optional[str]: - """ - Constructs a subject reference name in the form of {reference name}. - - Args: - ctx (SerializationContext): Metadata pertaining to the serialization - operation. - - schema_ref (SchemaReference): SchemaReference instance. - - """ - return schema_ref.name if schema_ref is not None else None diff --git a/src/confluent_kafka/schema_registry/_sync/__init__.py b/src/confluent_kafka/schema_registry/_sync/__init__.py new file mode 100644 index 000000000..2b4389a06 --- /dev/null +++ b/src/confluent_kafka/schema_registry/_sync/__init__.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py new file mode 100644 index 000000000..a7cce69b2 --- /dev/null +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -0,0 +1,577 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from json import loads +from struct import pack, unpack +from typing import Dict, Union, Optional, Callable + +from fastavro import schemaless_reader, schemaless_writer + +from confluent_kafka.schema_registry.common.avro import AvroSchema, _schema_loads, get_inline_tags, parse_schema_with_repo, transform + +from confluent_kafka.schema_registry import (_MAGIC_BYTE, + Schema, + topic_subject_name_strategy, + RuleMode, + SchemaRegistryClient) +from confluent_kafka.serialization import (SerializationError, + SerializationContext) +from confluent_kafka.schema_registry.common import _ContextStringIO +from confluent_kafka.schema_registry.rule_registry import RuleRegistry +from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, ParsedSchemaCache + + +def _resolve_named_schema( + schema: Schema, schema_registry_client: SchemaRegistryClient +) -> Dict[str, AvroSchema]: + """ + Resolves named schemas referenced by the provided schema recursively. + :param schema: Schema to resolve named schemas for. + :param schema_registry_client: SchemaRegistryClient to use for retrieval. + :return: named_schemas dict. + """ + named_schemas = {} + if schema.references is not None: + for ref in schema.references: + referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) + ref_named_schemas = _resolve_named_schema(referenced_schema.schema, schema_registry_client) + parsed_schema = parse_schema_with_repo( + referenced_schema.schema.schema_str, named_schemas=ref_named_schemas) + named_schemas.update(ref_named_schemas) + named_schemas[ref.name] = parsed_schema + return named_schemas + + +class AvroSerializer(BaseSerializer): + """ + Serializer that outputs Avro binary encoded data with Confluent Schema Registry framing. + + Configuration properties: + + +-----------------------------+----------+--------------------------------------------------+ + | Property Name | Type | Description | + +=============================+==========+==================================================+ + | | | If True, automatically register the configured | + | ``auto.register.schemas`` | bool | schema with Confluent Schema Registry if it has | + | | | not previously been associated with the relevant | + | | | subject (determined via subject.name.strategy). | + | | | | + | | | Defaults to True. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to normalize schemas, which will | + | ``normalize.schemas`` | bool | transform schemas to have a consistent format, | + | | | including ordering properties and references. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the given schema ID for | + | ``use.schema.id`` | int | serialization. | + | | | | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | serialization. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to False. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata``| dict | the given metadata. 
| + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to None. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-----------------------------+----------+--------------------------------------------------+ + + Schemas are registered against subject names in Confluent Schema Registry that + define a scope in which the schemas can be evolved. By default, the subject name + is formed by concatenating the topic name with the message field (key or value) + separated by a hyphen. + + i.e. {topic name}-{message field} + + Alternative naming strategies may be configured with the property + ``subject.name.strategy``. + + Supported subject name strategies: + + +--------------------------------------+------------------------------+ + | Subject Name Strategy | Output Format | + +======================================+==============================+ + | topic_subject_name_strategy(default) | {topic name}-{message field} | + +--------------------------------------+------------------------------+ + | topic_record_subject_name_strategy | {topic name}-{record name} | + +--------------------------------------+------------------------------+ + | record_subject_name_strategy | {record name} | + +--------------------------------------+------------------------------+ + + See `Subject name strategy `_ for additional details. + + Note: + Prior to serialization, all values must first be converted to + a dict instance. This may handled manually prior to calling + :py:func:`Producer.produce()` or by registering a `to_dict` + callable with AvroSerializer. + + See ``avro_producer.py`` in the examples directory for example usage. + + Note: + Tuple notation can be used to determine which branch of an ambiguous union to take. + + See `fastavro notation `_ + + Args: + schema_registry_client (SchemaRegistryClient): Schema Registry client instance. + + schema_str (str or Schema): + Avro `Schema Declaration. `_ + Accepts either a string or a :py:class:`Schema` instance. Note that string + definitions cannot reference other schemas. For referencing other schemas, + use a :py:class:`Schema` instance. + + to_dict (callable, optional): Callable(object, SerializationContext) -> dict. Converts object to a dict. + + conf (dict): AvroSerializer configuration. 
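+
+    A minimal usage sketch of this serializer follows; the registry URL, topic
+    name, and schema are illustrative placeholders only, assuming a locally
+    running Schema Registry::
+
+        from confluent_kafka.schema_registry import SchemaRegistryClient
+        from confluent_kafka.schema_registry.avro import AvroSerializer
+        from confluent_kafka.serialization import SerializationContext, MessageField
+
+        client = SchemaRegistryClient({'url': 'http://localhost:8081'})  # placeholder URL
+        # Placeholder record schema for illustration only
+        user_schema = '{"type": "record", "name": "User", "fields": [{"name": "name", "type": "string"}]}'
+        avro_serializer = AvroSerializer(client, user_schema,
+                                         conf={'auto.register.schemas': True})
+
+        # Produces Confluent-framed Avro bytes: magic byte, schema ID, then the record
+        payload = avro_serializer({"name": "Alice"},
+                                  SerializationContext("users", MessageField.VALUE))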
+ """ # noqa: E501 + __slots__ = ['_known_subjects', '_parsed_schema', '_schema', + '_schema_id', '_schema_name', '_to_dict', '_parsed_schemas'] + + _default_conf = {'auto.register.schemas': True, + 'normalize.schemas': False, + 'use.schema.id': None, + 'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy} + + def __init__( + self, + schema_registry_client: SchemaRegistryClient, + schema_str: Union[str, Schema, None] = None, + to_dict: Optional[Callable[[object, SerializationContext], dict]] = None, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None + ): + super().__init__() + if isinstance(schema_str, str): + schema = _schema_loads(schema_str) + elif isinstance(schema_str, Schema): + schema = schema_str + else: + schema = None + + self._registry = schema_registry_client + self._schema_id = None + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._known_subjects = set() + self._parsed_schemas = ParsedSchemaCache() + + if to_dict is not None and not callable(to_dict): + raise ValueError("to_dict must be callable with the signature " + "to_dict(object, SerializationContext)->dict") + + self._to_dict = to_dict + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._auto_register = conf_copy.pop('auto.register.schemas') + if not isinstance(self._auto_register, bool): + raise ValueError("auto.register.schemas must be a boolean value") + + self._normalize_schemas = conf_copy.pop('normalize.schemas') + if not isinstance(self._normalize_schemas, bool): + raise ValueError("normalize.schemas must be a boolean value") + + self._use_schema_id = conf_copy.pop('use.schema.id') + if (self._use_schema_id is not None and + not isinstance(self._use_schema_id, int)): + raise ValueError("use.schema.id must be an int value") + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + if self._use_latest_version and self._auto_register: + raise ValueError("cannot enable both use.latest.version and auto.register.schemas") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + if schema: + parsed_schema = self._get_parsed_schema(schema) + + if isinstance(parsed_schema, list): + # if parsed_schema is a list, we have an Avro union and there + # is no valid schema name. This is fine because the only use of + # schema_name is for supplying the subject name to the registry + # and union types should use topic_subject_name_strategy, which + # just discards the schema name anyway + schema_name = None + else: + # The Avro spec states primitives have a name equal to their type + # i.e. {"type": "string"} has a name of string. + # This function does not comply. 
+ # https://github.com/fastavro/fastavro/issues/415 + schema_dict = loads(schema.schema_str) + schema_name = parsed_schema.get("name", schema_dict.get("type")) + else: + schema_name = None + parsed_schema = None + + self._schema = schema + self._schema_name = schema_name + self._parsed_schema = parsed_schema + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + return self.__serialize(obj, ctx) + + def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + """ + Serializes an object to Avro binary format, prepending it with Confluent + Schema Registry framing. + + Args: + obj (object): The object instance to serialize. + + ctx (SerializationContext): Metadata pertaining to the serialization operation. + + Raises: + SerializerError: If any error occurs serializing obj. + SchemaRegistryError: If there was an error registering the schema with + Schema Registry, or auto.register.schemas is + false and the schema was not registered. + + Returns: + bytes: Confluent Schema Registry encoded Avro bytes + """ + + if obj is None: + return None + + subject = self._subject_name_func(ctx, self._schema_name) + latest_schema = self._get_reader_schema(subject) + if latest_schema is not None: + self._schema_id = latest_schema.schema_id + elif subject not in self._known_subjects: + # Check to ensure this schema has been registered under subject_name. + if self._auto_register: + # The schema name will always be the same. We can't however register + # a schema without a subject so we set the schema_id here to handle + # the initial registration. + self._schema_id = self._registry.register_schema( + subject, self._schema, self._normalize_schemas) + else: + registered_schema = self._registry.lookup_schema( + subject, self._schema, self._normalize_schemas) + self._schema_id = registered_schema.schema_id + + self._known_subjects.add(subject) + + if self._to_dict is not None: + value = self._to_dict(obj, ctx) + else: + value = obj + + if latest_schema is not None: + parsed_schema = self._get_parsed_schema(latest_schema.schema) + field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 + transform(rule_ctx, parsed_schema, msg, field_transform)) + value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, value, get_inline_tags(parsed_schema), + field_transformer) + else: + parsed_schema = self._parsed_schema + + with _ContextStringIO() as fo: + # Write the magic byte and schema ID in network byte order (big endian) + fo.write(pack('>bI', _MAGIC_BYTE, self._schema_id)) + # write the record to the rest of the buffer + schemaless_writer(fo, parsed_schema, value) + + return fo.getvalue() + + def _get_parsed_schema(self, schema: Schema) -> AvroSchema: + parsed_schema = self._parsed_schemas.get_parsed_schema(schema) + if parsed_schema is not None: + return parsed_schema + + named_schemas = _resolve_named_schema(schema, self._registry) + prepared_schema = _schema_loads(schema.schema_str) + parsed_schema = parse_schema_with_repo( + prepared_schema.schema_str, named_schemas=named_schemas) + + self._parsed_schemas.set(schema, parsed_schema) + return parsed_schema + + + +class AvroDeserializer(BaseDeserializer): + """ + Deserializer for Avro binary encoded data with Confluent Schema Registry + framing. 
+ + +-----------------------------+----------+--------------------------------------------------+ + | Property Name | Type | Description | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | deserialization. | + | | | | + | | | Defaults to False. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata``| dict | the given metadata. | + | | | | + | | | Defaults to None. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-----------------------------+----------+--------------------------------------------------+ + + Note: + By default, Avro complex types are returned as dicts. This behavior can + be overridden by registering a callable ``from_dict`` with the deserializer to + convert the dicts to the desired type. + + See ``avro_consumer.py`` in the examples directory in the examples + directory for example usage. + + Args: + schema_registry_client (SchemaRegistryClient): Confluent Schema Registry + client instance. + + schema_str (str, Schema, optional): Avro reader schema declaration Accepts + either a string or a :py:class:`Schema` instance. If not provided, the + writer schema will be used as the reader schema. Note that string + definitions cannot reference other schemas. For referencing other schemas, + use a :py:class:`Schema` instance. + + from_dict (callable, optional): Callable(dict, SerializationContext) -> object. + Converts a dict to an instance of some object. + + return_record_name (bool): If True, when reading a union of records, the result will + be a tuple where the first value is the name of the record and the second value is + the record itself. Defaults to False. 
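+
+    A minimal usage sketch follows; the registry URL and topic name are
+    illustrative placeholders, and ``payload`` is assumed to be bytes previously
+    produced by an AvroSerializer for the same subject::
+
+        from confluent_kafka.schema_registry import SchemaRegistryClient
+        from confluent_kafka.schema_registry.avro import AvroDeserializer
+        from confluent_kafka.serialization import SerializationContext, MessageField
+
+        client = SchemaRegistryClient({'url': 'http://localhost:8081'})  # placeholder URL
+        # No reader schema given: the writer schema from the registry is used
+        avro_deserializer = AvroDeserializer(client)
+
+        # Resolves the writer schema from the 5-byte header and returns a dict
+        user = avro_deserializer(payload, SerializationContext("users", MessageField.VALUE))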
+ + See Also: + `Apache Avro Schema Declaration `_ + + `Apache Avro Schema Resolution `_ + """ + + __slots__ = ['_reader_schema', '_from_dict', '_return_record_name', + '_schema', '_parsed_schemas'] + + _default_conf = {'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy} + + def __init__( + self, + schema_registry_client: SchemaRegistryClient, + schema_str: Union[str, Schema, None] = None, + from_dict: Optional[Callable[[dict, SerializationContext], object]] = None, + return_record_name: bool = False, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None + ): + super().__init__() + schema = None + if schema_str is not None: + if isinstance(schema_str, str): + schema = _schema_loads(schema_str) + elif isinstance(schema_str, Schema): + schema = schema_str + else: + raise TypeError('You must pass either schema string or schema object') + + self._schema = schema + self._registry = schema_registry_client + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._parsed_schemas = ParsedSchemaCache() + self._use_schema_id = None + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + if schema: + self._reader_schema = self._get_parsed_schema(self._schema) + else: + self._reader_schema = None + + if from_dict is not None and not callable(from_dict): + raise ValueError("from_dict must be callable with the signature " + "from_dict(SerializationContext, dict) -> object") + self._from_dict = from_dict + + self._return_record_name = return_record_name + if not isinstance(self._return_record_name, bool): + raise ValueError("return_record_name must be a boolean value") + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + return self.__deserialize(data, ctx) + + def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + """ + Deserialize Avro binary encoded data with Confluent Schema Registry framing to + a dict, or object instance according to from_dict, if specified. + + Arguments: + data (bytes): bytes + + ctx (SerializationContext): Metadata relevant to the serialization + operation. + + Raises: + SerializerError: if an error occurs parsing data. + + Returns: + object: If data is None, then None. Else, a dict, or object instance according + to from_dict, if specified. 
+ """ # noqa: E501 + + if data is None: + return None + + if len(data) <= 5: + raise SerializationError("Expecting data framing of length 6 bytes or " + "more but total data size is {} bytes. This " + "message was not produced with a Confluent " + "Schema Registry serializer".format(len(data))) + + subject = self._subject_name_func(ctx, None) + latest_schema = None + if subject is not None: + latest_schema = self._get_reader_schema(subject) + + with _ContextStringIO(data) as payload: + magic, schema_id = unpack('>bI', payload.read(5)) + if magic != _MAGIC_BYTE: + raise SerializationError("Unexpected magic byte {}. This message " + "was not produced with a Confluent " + "Schema Registry serializer".format(magic)) + + writer_schema_raw = self._registry.get_schema(schema_id) + writer_schema = self._get_parsed_schema(writer_schema_raw) + + if subject is None: + subject = self._subject_name_func(ctx, writer_schema.get("name")) + if subject is not None: + latest_schema = self._get_reader_schema(subject) + + if latest_schema is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) + reader_schema_raw = latest_schema.schema + reader_schema = self._get_parsed_schema(latest_schema.schema) + elif self._schema is not None: + migrations = None + reader_schema_raw = self._schema + reader_schema = self._reader_schema + else: + migrations = None + reader_schema_raw = writer_schema_raw + reader_schema = writer_schema + + if migrations: + obj_dict = schemaless_reader(payload, + writer_schema, + None, + self._return_record_name) + obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) + else: + obj_dict = schemaless_reader(payload, + writer_schema, + reader_schema, + self._return_record_name) + + field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + transform(rule_ctx, reader_schema, message, field_transform)) + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, obj_dict, get_inline_tags(reader_schema), + field_transformer) + + if self._from_dict is not None: + return self._from_dict(obj_dict, ctx) + + return obj_dict + + def _get_parsed_schema(self, schema: Schema) -> AvroSchema: + parsed_schema = self._parsed_schemas.get_parsed_schema(schema) + if parsed_schema is not None: + return parsed_schema + + named_schemas = _resolve_named_schema(schema, self._registry) + prepared_schema = _schema_loads(schema.schema_str) + parsed_schema = parse_schema_with_repo( + prepared_schema.schema_str, named_schemas=named_schemas) + + self._parsed_schemas.set(schema, parsed_schema) + return parsed_schema diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py new file mode 100644 index 000000000..ac2513913 --- /dev/null +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -0,0 +1,643 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import struct +from typing import Union, Optional, Tuple, Callable + +from cachetools import LRUCache +from jsonschema import ValidationError +from jsonschema.protocols import Validator +from jsonschema.validators import validator_for +from referencing import Registry, Resource + +from confluent_kafka.schema_registry import (_MAGIC_BYTE, + Schema, + topic_subject_name_strategy, + RuleMode, SchemaRegistryClient) + +from confluent_kafka.schema_registry.common.json_schema import ( + DEFAULT_SPEC, JsonSchema, _retrieve_via_httpx, transform +) +from confluent_kafka.schema_registry.common import _ContextStringIO +from confluent_kafka.schema_registry.rule_registry import RuleRegistry +from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, \ + ParsedSchemaCache +from confluent_kafka.serialization import (SerializationError, + SerializationContext) + + +def _resolve_named_schema( + schema: Schema, schema_registry_client: SchemaRegistryClient, + ref_registry: Optional[Registry] = None +) -> Registry: + """ + Resolves named schemas referenced by the provided schema recursively. + :param schema: Schema to resolve named schemas for. + :param schema_registry_client: SchemaRegistryClient to use for retrieval. + :param ref_registry: Registry of named schemas resolved recursively. + :return: Registry + """ + if ref_registry is None: + # Retrieve external schemas for backward compatibility + ref_registry = Registry(retrieve=_retrieve_via_httpx) + if schema.references is not None: + for ref in schema.references: + referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) + ref_registry = _resolve_named_schema(referenced_schema.schema, schema_registry_client, ref_registry) + referenced_schema_dict = json.loads(referenced_schema.schema.schema_str) + resource = Resource.from_contents( + referenced_schema_dict, default_specification=DEFAULT_SPEC) + ref_registry = ref_registry.with_resource(ref.name, resource) + return ref_registry + + + +class JSONSerializer(BaseSerializer): + """ + Serializer that outputs JSON encoded data with Confluent Schema Registry framing. + + Configuration properties: + + +-----------------------------+----------+----------------------------------------------------+ + | Property Name | Type | Description | + +=============================+==========+====================================================+ + | | | If True, automatically register the configured | + | ``auto.register.schemas`` | bool | schema with Confluent Schema Registry if it has | + | | | not previously been associated with the relevant | + | | | subject (determined via subject.name.strategy). | + | | | | + | | | Defaults to True. | + | | | | + | | | Raises SchemaRegistryError if the schema was not | + | | | registered against the subject, or could not be | + | | | successfully registered. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to normalize schemas, which will | + | ``normalize.schemas`` | bool | transform schemas to have a consistent format, | + | | | including ordering properties and references. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to use the given schema ID for | + | ``use.schema.id`` | int | serialization. 
| + | | | | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | serialization. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to False. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata``| dict | the given metadata. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to None. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to validate the payload against the | + | ``validate`` | bool | the given schema. | + | | | | + +-----------------------------+----------+----------------------------------------------------+ + + Schemas are registered against subject names in Confluent Schema Registry that + define a scope in which the schemas can be evolved. By default, the subject name + is formed by concatenating the topic name with the message field (key or value) + separated by a hyphen. + + i.e. {topic name}-{message field} + + Alternative naming strategies may be configured with the property + ``subject.name.strategy``. + + Supported subject name strategies: + + +--------------------------------------+------------------------------+ + | Subject Name Strategy | Output Format | + +======================================+==============================+ + | topic_subject_name_strategy(default) | {topic name}-{message field} | + +--------------------------------------+------------------------------+ + | topic_record_subject_name_strategy | {topic name}-{record name} | + +--------------------------------------+------------------------------+ + | record_subject_name_strategy | {record name} | + +--------------------------------------+------------------------------+ + + See `Subject name strategy `_ for additional details. + + Notes: + The ``title`` annotation, referred to elsewhere as a record name + is not strictly required by the JSON Schema specification. It is + however required by this serializer in order to register the schema + with Confluent Schema Registry. + + Prior to serialization, all objects must first be converted to + a dict instance. This may be handled manually prior to calling + :py:func:`Producer.produce()` or by registering a `to_dict` + callable with JSONSerializer. + + Args: + schema_str (str, Schema): + `JSON Schema definition. `_ + Accepts schema as either a string or a :py:class:`Schema` instance. + Note that string definitions cannot reference other schemas. For + referencing other schemas, use a :py:class:`Schema` instance. + + schema_registry_client (SchemaRegistryClient): Schema Registry + client instance. 
+ + to_dict (callable, optional): Callable(object, SerializationContext) -> dict. + Converts object to a dict. + + conf (dict): JsonSerializer configuration. + """ # noqa: E501 + __slots__ = ['_known_subjects', '_parsed_schema', '_ref_registry', + '_schema', '_schema_id', '_schema_name', '_to_dict', + '_parsed_schemas', '_validators', '_validate', '_json_encode'] + + _default_conf = {'auto.register.schemas': True, + 'normalize.schemas': False, + 'use.schema.id': None, + 'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy, + 'validate': True} + + def __init__( + self, + schema_str: Union[str, Schema, None], + schema_registry_client: SchemaRegistryClient, + to_dict: Optional[Callable[[object, SerializationContext], dict]] = None, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None, + json_encode: Optional[Callable] = None, + ): + super().__init__() + if isinstance(schema_str, str): + self._schema = Schema(schema_str, schema_type="JSON") + elif isinstance(schema_str, Schema): + self._schema = schema_str + else: + self._schema = None + + self._json_encode = json_encode or json.dumps + self._registry = schema_registry_client + self._rule_registry = ( + rule_registry if rule_registry else RuleRegistry.get_global_instance() + ) + self._schema_id = None + self._known_subjects = set() + self._parsed_schemas = ParsedSchemaCache() + self._validators = LRUCache(1000) + + if to_dict is not None and not callable(to_dict): + raise ValueError("to_dict must be callable with the signature " + "to_dict(object, SerializationContext)->dict") + + self._to_dict = to_dict + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._auto_register = conf_copy.pop('auto.register.schemas') + if not isinstance(self._auto_register, bool): + raise ValueError("auto.register.schemas must be a boolean value") + + self._normalize_schemas = conf_copy.pop('normalize.schemas') + if not isinstance(self._normalize_schemas, bool): + raise ValueError("normalize.schemas must be a boolean value") + + self._use_schema_id = conf_copy.pop('use.schema.id') + if (self._use_schema_id is not None and + not isinstance(self._use_schema_id, int)): + raise ValueError("use.schema.id must be an int value") + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + if self._use_latest_version and self._auto_register: + raise ValueError("cannot enable both use.latest.version and auto.register.schemas") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + self._validate = conf_copy.pop('validate') + if not isinstance(self._normalize_schemas, bool): + raise ValueError("validate must be a boolean value") + + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + schema_dict, ref_registry = self._get_parsed_schema(self._schema) + if schema_dict: + schema_name = schema_dict.get('title', None) + else: + schema_name = None + + 
self._schema_name = schema_name + self._parsed_schema = schema_dict + self._ref_registry = ref_registry + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + return self.__serialize(obj, ctx) + + def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + """ + Serializes an object to JSON, prepending it with Confluent Schema Registry + framing. + + Args: + obj (object): The object instance to serialize. + + ctx (SerializationContext): Metadata relevant to the serialization + operation. + + Raises: + SerializerError if any error occurs serializing obj. + + Returns: + bytes: None if obj is None, else a byte array containing the JSON + serialized data with Confluent Schema Registry framing. + """ + + if obj is None: + return None + + subject = self._subject_name_func(ctx, self._schema_name) + latest_schema = self._get_reader_schema(subject) + if latest_schema is not None: + self._schema_id = latest_schema.schema_id + elif subject not in self._known_subjects: + # Check to ensure this schema has been registered under subject_name. + if self._auto_register: + # The schema name will always be the same. We can't however register + # a schema without a subject so we set the schema_id here to handle + # the initial registration. + self._schema_id = self._registry.register_schema(subject, + self._schema, + self._normalize_schemas) + else: + registered_schema = self._registry.lookup_schema(subject, + self._schema, + self._normalize_schemas) + self._schema_id = registered_schema.schema_id + + self._known_subjects.add(subject) + + if self._to_dict is not None: + value = self._to_dict(obj, ctx) + else: + value = obj + + if latest_schema is not None: + schema = latest_schema.schema + parsed_schema, ref_registry = self._get_parsed_schema(latest_schema.schema) + root_resource = Resource.from_contents( + parsed_schema, default_specification=DEFAULT_SPEC) + ref_resolver = ref_registry.resolver_with_root(root_resource) + field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 + transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) + value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, value, None, + field_transformer) + else: + schema = self._schema + parsed_schema, ref_registry = self._parsed_schema, self._ref_registry + + if self._validate: + try: + validator = self._get_validator(schema, parsed_schema, ref_registry) + validator.validate(value) + except ValidationError as ve: + raise SerializationError(ve.message) + + with _ContextStringIO() as fo: + # Write the magic byte and schema ID in network byte order (big endian) + fo.write(struct.pack(">bI", _MAGIC_BYTE, self._schema_id)) + # JSON dump always writes a str never bytes + # https://docs.python.org/3/library/json.html + encoded_value = self._json_encode(value) + if isinstance(encoded_value, str): + encoded_value = encoded_value.encode("utf8") + fo.write(encoded_value) + + return fo.getvalue() + + def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]: + if schema is None: + return None, None + + result = self._parsed_schemas.get_parsed_schema(schema) + if result is not None: + return result + + ref_registry = _resolve_named_schema(schema, self._registry) + parsed_schema = 
json.loads(schema.schema_str) + + self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) + return parsed_schema, ref_registry + + def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Registry) -> Validator: + validator = self._validators.get(schema, None) + if validator is not None: + return validator + + cls = validator_for(parsed_schema) + cls.check_schema(parsed_schema) + validator = cls(parsed_schema, registry=registry) + + self._validators[schema] = validator + return validator + + +class JSONDeserializer(BaseDeserializer): + """ + Deserializer for JSON encoded data with Confluent Schema Registry + framing. + + Configuration properties: + + +-----------------------------+----------+----------------------------------------------------+ + | Property Name | Type | Description | + +=============================+==========+====================================================+ + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | deserialization. | + | | | | + | | | Defaults to False. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata``| dict | the given metadata. | + | | | | + | | | Defaults to None. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to validate the payload against the | + | ``validate`` | bool | the given schema. | + | | | | + +-----------------------------+----------+----------------------------------------------------+ + + Args: + schema_str (str, Schema, optional): + `JSON schema definition `_ + Accepts schema as either a string or a :py:class:`Schema` instance. + Note that string definitions cannot reference other schemas. For referencing other schemas, + use a :py:class:`Schema` instance. If not provided, schemas will be + retrieved from schema_registry_client based on the schema ID in the + wire header of each message. + + from_dict (callable, optional): Callable(dict, SerializationContext) -> object. + Converts a dict to a Python object instance. + + schema_registry_client (SchemaRegistryClient, optional): Schema Registry client instance. Needed if ``schema_str`` is a schema referencing other schemas or is not provided. 
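+
+    A minimal usage sketch follows; the schema and topic name are illustrative
+    placeholders, and ``payload`` is assumed to be bytes previously produced by
+    a JSONSerializer for the same subject::
+
+        from confluent_kafka.schema_registry.json_schema import JSONDeserializer
+        from confluent_kafka.serialization import SerializationContext, MessageField
+
+        # Placeholder schema; the title doubles as the record name
+        schema_str = '{"title": "User", "type": "object", "properties": {"name": {"type": "string"}}}'
+        json_deserializer = JSONDeserializer(schema_str)
+
+        # Validates against the reader schema (validate=True by default) and returns a dict
+        user = json_deserializer(payload, SerializationContext("users", MessageField.VALUE))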
+ """ # noqa: E501 + + __slots__ = ['_reader_schema', '_ref_registry', '_from_dict', '_schema', + '_parsed_schemas', '_validators', '_validate', '_json_decode'] + + _default_conf = {'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy, + 'validate': True} + + def __init__( + self, + schema_str: Union[str, Schema, None], + from_dict: Optional[Callable[[dict, SerializationContext], object]] = None, + schema_registry_client: Optional[SchemaRegistryClient] = None, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None, + json_decode: Optional[Callable] = None, + ): + super().__init__() + if isinstance(schema_str, str): + schema = Schema(schema_str, schema_type="JSON") + elif isinstance(schema_str, Schema): + schema = schema_str + if bool(schema.references) and schema_registry_client is None: + raise ValueError( + """schema_registry_client must be provided if "schema_str" is a Schema instance with references""") + elif schema_str is None: + if schema_registry_client is None: + raise ValueError( + """schema_registry_client must be provided if "schema_str" is not provided""" + ) + schema = schema_str + else: + raise TypeError('You must pass either str or Schema') + + self._schema = schema + self._registry = schema_registry_client + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._parsed_schemas = ParsedSchemaCache() + self._validators = LRUCache(1000) + self._json_decode = json_decode or json.loads + self._use_schema_id = None + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + self._validate = conf_copy.pop('validate') + if not isinstance(self._validate, bool): + raise ValueError("validate must be a boolean value") + + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + if schema: + self._reader_schema, self._ref_registry = self._get_parsed_schema(self._schema) + else: + self._reader_schema, self._ref_registry = None, None + + if from_dict is not None and not callable(from_dict): + raise ValueError("from_dict must be callable with the signature" + " from_dict(dict, SerializationContext) -> object") + + self._from_dict = from_dict + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + return self.__serialize(data, ctx) + + def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + """ + Deserialize a JSON encoded record with Confluent Schema Registry framing to + a dict, or object instance according to from_dict if from_dict is specified. 
+ + Args: + data (bytes): A JSON serialized record with Confluent Schema Registry framing. + + ctx (SerializationContext): Metadata relevant to the serialization operation. + + Returns: + A dict, or object instance according to from_dict if from_dict is specified. + + Raises: + SerializerError: If there was an error reading the Confluent framing data, or + if ``data`` was not successfully validated with the configured schema. + """ + + if data is None: + return None + + if len(data) <= 5: + raise SerializationError("Expecting data framing of length 6 bytes or " + "more but total data size is {} bytes. This " + "message was not produced with a Confluent " + "Schema Registry serializer".format(len(data))) + + subject = self._subject_name_func(ctx, None) + latest_schema = None + if subject is not None and self._registry is not None: + latest_schema = self._get_reader_schema(subject) + + with _ContextStringIO(data) as payload: + magic, schema_id = struct.unpack('>bI', payload.read(5)) + if magic != _MAGIC_BYTE: + raise SerializationError("Unexpected magic byte {}. This message " + "was not produced with a Confluent " + "Schema Registry serializer".format(magic)) + + # JSON documents are self-describing; no need to query schema + obj_dict = self._json_decode(payload.read()) + + if self._registry is not None: + writer_schema_raw = self._registry.get_schema(schema_id) + writer_schema, writer_ref_registry = self._get_parsed_schema(writer_schema_raw) + if subject is None: + subject = self._subject_name_func(ctx, writer_schema.get("title")) + if subject is not None: + latest_schema = self._get_reader_schema(subject) + else: + writer_schema_raw = None + writer_schema, writer_ref_registry = None, None + + if latest_schema is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) + reader_schema_raw = latest_schema.schema + reader_schema, reader_ref_registry = self._get_parsed_schema(latest_schema.schema) + elif self._schema is not None: + migrations = None + reader_schema_raw = self._schema + reader_schema, reader_ref_registry = self._reader_schema, self._ref_registry + else: + migrations = None + reader_schema_raw = writer_schema_raw + reader_schema, reader_ref_registry = writer_schema, writer_ref_registry + + if migrations: + obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) + + reader_root_resource = Resource.from_contents( + reader_schema, default_specification=DEFAULT_SPEC) + reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) + field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, + "$", message, field_transform)) + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, obj_dict, None, + field_transformer) + + if self._validate: + try: + validator = self._get_validator(reader_schema_raw, reader_schema, reader_ref_registry) + validator.validate(obj_dict) + except ValidationError as ve: + raise SerializationError(ve.message) + + if self._from_dict is not None: + return self._from_dict(obj_dict, ctx) + + return obj_dict + + def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]: + if schema is None: + return None, None + + result = self._parsed_schemas.get_parsed_schema(schema) + if result is not None: + return result + + ref_registry = _resolve_named_schema(schema, self._registry) + parsed_schema = json.loads(schema.schema_str) + + 
self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) + return parsed_schema, ref_registry + + def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Registry) -> Validator: + validator = self._validators.get(schema, None) + if validator is not None: + return validator + + cls = validator_for(parsed_schema) + cls.check_schema(parsed_schema) + validator = cls(parsed_schema, registry=registry) + + self._validators[schema] = validator + return validator diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py new file mode 100644 index 000000000..7a263d124 --- /dev/null +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -0,0 +1,801 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020-2022 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import struct +import warnings +from typing import Set, List, Union, Optional, Tuple + +from google.protobuf import json_format, descriptor_pb2 +from google.protobuf.descriptor_pool import DescriptorPool +from google.protobuf.descriptor import Descriptor, FileDescriptor +from google.protobuf.message import DecodeError, Message +from google.protobuf.message_factory import GetMessageClass + +from confluent_kafka.schema_registry.common import (_MAGIC_BYTE, _ContextStringIO, + reference_subject_name_strategy, + topic_subject_name_strategy) +from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient +from confluent_kafka.schema_registry.common.protobuf import _bytes, _create_index_array, _init_pool, _is_builtin, _schema_to_str, _str_to_proto, transform +from confluent_kafka.schema_registry.rule_registry import RuleRegistry +from confluent_kafka.schema_registry import (Schema, + SchemaReference, + RuleMode) +from confluent_kafka.serialization import SerializationError, \ + SerializationContext + +from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, ParsedSchemaCache + + +def _resolve_named_schema( + schema: Schema, + schema_registry_client: SchemaRegistryClient, + pool: DescriptorPool, + visited: Optional[Set[str]] = None +): + """ + Resolves named schemas referenced by the provided schema recursively. + :param schema: Schema to resolve named schemas for. + :param schema_registry_client: SchemaRegistryClient to use for retrieval. + :param pool: DescriptorPool to add resolved schemas to. 
+ :return: DescriptorPool + """ + if visited is None: + visited = set() + if schema.references is not None: + for ref in schema.references: + if _is_builtin(ref.name) or ref.name in visited: + continue + visited.add(ref.name) + referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True, 'serialized') + _resolve_named_schema(referenced_schema.schema, schema_registry_client, pool, visited) + file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) + pool.Add(file_descriptor_proto) + + + +class ProtobufSerializer(BaseSerializer): + """ + Serializer for Protobuf Message derived classes. Serialization format is Protobuf, + with Confluent Schema Registry framing. + + Configuration properties: + + +-------------------------------------+----------+------------------------------------------------------+ + | Property Name | Type | Description | + +=====================================+==========+======================================================+ + | | | If True, automatically register the configured | + | ``auto.register.schemas`` | bool | schema with Confluent Schema Registry if it has | + | | | not previously been associated with the relevant | + | | | subject (determined via subject.name.strategy). | + | | | | + | | | Defaults to True. | + | | | | + | | | Raises SchemaRegistryError if the schema was not | + | | | registered against the subject, or could not be | + | | | successfully registered. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to normalize schemas, which will | + | ``normalize.schemas`` | bool | transform schemas to have a consistent format, | + | | | including ordering properties and references. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to use the given schema ID for | + | ``use.schema.id`` | int | serialization. | + | | | | + +-----------------------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | serialization. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to False. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata`` | dict | the given metadata. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to None. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether or not to skip known types when resolving | + | ``skip.known.types`` | bool | schema dependencies. | + | | | | + | | | Defaults to True. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. 
| + +-------------------------------------+----------+------------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``reference.subject.name.strategy`` | callable | Defines how Schema Registry subject names for schema | + | | | references are constructed. | + | | | | + | | | Defaults to reference_subject_name_strategy | + +-------------------------------------+----------+------------------------------------------------------+ + | ``use.deprecated.format`` | bool | Specifies whether the Protobuf serializer should | + | | | serialize message indexes without zig-zag encoding. | + | | | This option must be explicitly configured as older | + | | | and newer Protobuf producers are incompatible. | + | | | If the consumers of the topic being produced to are | + | | | using confluent-kafka-python <1.8 then this property | + | | | must be set to True until all old consumers have | + | | | have been upgraded. | + | | | | + | | | Warning: This configuration property will be removed | + | | | in a future version of the client. | + +-------------------------------------+----------+------------------------------------------------------+ + + Schemas are registered against subject names in Confluent Schema Registry that + define a scope in which the schemas can be evolved. By default, the subject name + is formed by concatenating the topic name with the message field (key or value) + separated by a hyphen. + + i.e. {topic name}-{message field} + + Alternative naming strategies may be configured with the property + ``subject.name.strategy``. + + Supported subject name strategies + + +--------------------------------------+------------------------------+ + | Subject Name Strategy | Output Format | + +======================================+==============================+ + | topic_subject_name_strategy(default) | {topic name}-{message field} | + +--------------------------------------+------------------------------+ + | topic_record_subject_name_strategy | {topic name}-{record name} | + +--------------------------------------+------------------------------+ + | record_subject_name_strategy | {record name} | + +--------------------------------------+------------------------------+ + + See `Subject name strategy `_ for additional details. + + Args: + msg_type (Message): Protobuf Message type. + + schema_registry_client (SchemaRegistryClient): Schema Registry + client instance. + + conf (dict): ProtobufSerializer configuration. 
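+
+        rule_conf (dict, optional): Configuration forwarded to any registered rule executors.
+
+        rule_registry (RuleRegistry, optional): Rule registry to use; defaults to the global rule registry.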
+ + See Also: + `Protobuf API reference `_ + """ # noqa: E501 + __slots__ = ['_skip_known_types', '_known_subjects', '_msg_class', '_index_array', + '_schema', '_schema_id', '_ref_reference_subject_func', + '_use_deprecated_format', '_parsed_schemas'] + + _default_conf = { + 'auto.register.schemas': True, + 'normalize.schemas': False, + 'use.schema.id': None, + 'use.latest.version': False, + 'use.latest.with.metadata': None, + 'skip.known.types': True, + 'subject.name.strategy': topic_subject_name_strategy, + 'reference.subject.name.strategy': reference_subject_name_strategy, + 'use.deprecated.format': False, + } + + def __init__( + self, + msg_type: Message, + schema_registry_client: SchemaRegistryClient, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None + ): + super().__init__() + + if conf is None or 'use.deprecated.format' not in conf: + raise RuntimeError( + "ProtobufSerializer: the 'use.deprecated.format' configuration " + "property must be explicitly set due to backward incompatibility " + "with older confluent-kafka-python Protobuf producers and consumers. " + "See the release notes for more details") + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._auto_register = conf_copy.pop('auto.register.schemas') + if not isinstance(self._auto_register, bool): + raise ValueError("auto.register.schemas must be a boolean value") + + self._normalize_schemas = conf_copy.pop('normalize.schemas') + if not isinstance(self._normalize_schemas, bool): + raise ValueError("normalize.schemas must be a boolean value") + + self._use_schema_id = conf_copy.pop('use.schema.id') + if (self._use_schema_id is not None and + not isinstance(self._use_schema_id, int)): + raise ValueError("use.schema.id must be an int value") + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + if self._use_latest_version and self._auto_register: + raise ValueError("cannot enable both use.latest.version and auto.register.schemas") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._skip_known_types = conf_copy.pop('skip.known.types') + if not isinstance(self._skip_known_types, bool): + raise ValueError("skip.known.types must be a boolean value") + + self._use_deprecated_format = conf_copy.pop('use.deprecated.format') + if not isinstance(self._use_deprecated_format, bool): + raise ValueError("use.deprecated.format must be a boolean value") + if self._use_deprecated_format: + warnings.warn("ProtobufSerializer: the 'use.deprecated.format' " + "configuration property, and the ability to use the " + "old incorrect Protobuf serializer heading format " + "introduced in confluent-kafka-python v1.4.0, " + "will be removed in an upcoming release in 2021 Q2. 
" + "Please migrate your Python Protobuf producers and " + "consumers to 'use.deprecated.format':False as " + "soon as possible") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + self._ref_reference_subject_func = conf_copy.pop( + 'reference.subject.name.strategy') + if not callable(self._ref_reference_subject_func): + raise ValueError("subject.name.strategy must be callable") + + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + self._registry = schema_registry_client + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._schema_id = None + self._known_subjects = set() + self._msg_class = msg_type + self._parsed_schemas = ParsedSchemaCache() + + descriptor = msg_type.DESCRIPTOR + self._index_array = _create_index_array(descriptor) + self._schema = Schema(_schema_to_str(descriptor.file), + schema_type='PROTOBUF') + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + @staticmethod + def _write_varint(buf: io.BytesIO, val: int, zigzag: bool = True): + """ + Writes val to buf, either using zigzag or uvarint encoding. + + Args: + buf (BytesIO): buffer to write to. + val (int): integer to be encoded. + zigzag (bool): whether to encode in zigzag or uvarint encoding + """ + + if zigzag: + val = (val << 1) ^ (val >> 63) + + while (val & ~0x7f) != 0: + buf.write(_bytes((val & 0x7f) | 0x80)) + val >>= 7 + buf.write(_bytes(val)) + + @staticmethod + def _encode_varints(buf: io.BytesIO, ints: List[int], zigzag: bool = True): + """ + Encodes each int as a uvarint onto buf + + Args: + buf (BytesIO): buffer to write to. + ints ([int]): ints to be encoded. + zigzag (bool): whether to encode in zigzag or uvarint encoding + """ + + assert len(ints) > 0 + # The root element at the 0 position does not need a length prefix. + if ints == [0]: + buf.write(_bytes(0x00)) + return + + ProtobufSerializer._write_varint(buf, len(ints), zigzag=zigzag) + + for value in ints: + ProtobufSerializer._write_varint(buf, value, zigzag=zigzag) + + def _resolve_dependencies( + self, ctx: SerializationContext, + file_desc: FileDescriptor + ) -> List[SchemaReference]: + """ + Resolves and optionally registers schema references recursively. + + Args: + ctx (SerializationContext): Serialization context. + + file_desc (FileDescriptor): file descriptor to traverse. 
+ """ + + schema_refs = [] + for dep in file_desc.dependencies: + if self._skip_known_types and _is_builtin(dep.name): + continue + dep_refs = self._resolve_dependencies(ctx, dep) + subject = self._ref_reference_subject_func(ctx, dep) + schema = Schema(_schema_to_str(dep), + references=dep_refs, + schema_type='PROTOBUF') + if self._auto_register: + self._registry.register_schema(subject, schema) + + reference = self._registry.lookup_schema(subject, schema) + # schema_refs are per file descriptor + schema_refs.append(SchemaReference(dep.name, + subject, + reference.version)) + return schema_refs + + def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + return self.__serialize(message, ctx) + + def __serialize(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + """ + Serializes an instance of a class derived from Protobuf Message, and prepends + it with Confluent Schema Registry framing. + + Args: + message (Message): An instance of a class derived from Protobuf Message. + + ctx (SerializationContext): Metadata relevant to the serialization. + operation. + + Raises: + SerializerError if any error occurs during serialization. + + Returns: + None if messages is None, else a byte array containing the Protobuf + serialized message with Confluent Schema Registry framing. + """ + + if message is None: + return None + + if not isinstance(message, self._msg_class): + raise ValueError("message must be of type {} not {}" + .format(self._msg_class, type(message))) + + subject = self._subject_name_func(ctx, + message.DESCRIPTOR.full_name) + latest_schema = self._get_reader_schema(subject, fmt='serialized') + if latest_schema is not None: + self._schema_id = latest_schema.schema_id + elif subject not in self._known_subjects: + references = self._resolve_dependencies( + ctx, message.DESCRIPTOR.file) + self._schema = Schema( + self._schema.schema_str, + self._schema.schema_type, + references + ) + + if self._auto_register: + self._schema_id = self._registry.register_schema(subject, + self._schema, + self._normalize_schemas) + else: + self._schema_id = self._registry.lookup_schema( + subject, self._schema, self._normalize_schemas).schema_id + + self._known_subjects.add(subject) + + if latest_schema is not None: + fd_proto, pool = self._get_parsed_schema(latest_schema.schema) + fd = pool.FindFileByName(fd_proto.name) + desc = fd.message_types_by_name[message.DESCRIPTOR.name] + field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 + transform(rule_ctx, desc, msg, field_transform)) + message = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, message, None, + field_transformer) + + with _ContextStringIO() as fo: + # Write the magic byte and schema ID in network byte order + # (big endian) + fo.write(struct.pack('>bI', _MAGIC_BYTE, self._schema_id)) + # write the index array that specifies the message descriptor + # of the serialized data. 
+ self._encode_varints(fo, self._index_array, + zigzag=not self._use_deprecated_format) + # write the serialized data itself + fo.write(message.SerializeToString()) + return fo.getvalue() + + def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]: + result = self._parsed_schemas.get_parsed_schema(schema) + if result is not None: + return result + + pool = DescriptorPool() + _init_pool(pool) + _resolve_named_schema(schema, self._registry, pool) + fd_proto = _str_to_proto("default", schema.schema_str) + pool.Add(fd_proto) + self._parsed_schemas.set(schema, (fd_proto, pool)) + return fd_proto, pool + + + +class ProtobufDeserializer(BaseDeserializer): + """ + Deserializer for Protobuf serialized data with Confluent Schema Registry framing. + + Args: + message_type (Message derived type): Protobuf Message type. + conf (dict): Configuration dictionary. + + ProtobufDeserializer configuration properties: + + +-------------------------------------+----------+------------------------------------------------------+ + | Property Name | Type | Description | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | deserialization. | + | | | | + | | | Defaults to False. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata`` | dict | the given metadata. | + | | | | + | | | Defaults to None. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka. schema_registry | + | | | namespace . | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-------------------------------------+----------+------------------------------------------------------+ + | ``use.deprecated.format`` | bool | Specifies whether the Protobuf deserializer should | + | | | deserialize message indexes without zig-zag encoding.| + | | | This option must be explicitly configured as older | + | | | and newer Protobuf producers are incompatible. | + | | | If Protobuf messages in the topic to consume were | + | | | produced with confluent-kafka-python <1.8 then this | + | | | property must be set to True until all old messages | + | | | have been processed and producers have been upgraded.| + | | | Warning: This configuration property will be removed | + | | | in a future version of the client. 
| + +-------------------------------------+----------+------------------------------------------------------+ + + + See Also: + `Protobuf API reference `_ + """ + + __slots__ = ['_msg_class', '_use_deprecated_format', '_parsed_schemas'] + + _default_conf = { + 'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy, + 'use.deprecated.format': False, + } + + def __init__( + self, + message_type: Message, + conf: Optional[dict] = None, + schema_registry_client: Optional[SchemaRegistryClient] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None + ): + super().__init__() + + self._registry = schema_registry_client + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._parsed_schemas = ParsedSchemaCache() + self._use_schema_id = None + + # Require use.deprecated.format to be explicitly configured + # during a transitionary period since old/new format are + # incompatible. + if conf is None or 'use.deprecated.format' not in conf: + raise RuntimeError( + "ProtobufDeserializer: the 'use.deprecated.format' configuration " + "property must be explicitly set due to backward incompatibility " + "with older confluent-kafka-python Protobuf producers and consumers. " + "See the release notes for more details") + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + self._use_deprecated_format = conf_copy.pop('use.deprecated.format') + if not isinstance(self._use_deprecated_format, bool): + raise ValueError("use.deprecated.format must be a boolean value") + if self._use_deprecated_format: + warnings.warn("ProtobufDeserializer: the 'use.deprecated.format' " + "configuration property, and the ability to use the " + "old incorrect Protobuf serializer heading format " + "introduced in confluent-kafka-python v1.4.0, " + "will be removed in an upcoming release in 2022 Q2. " + "Please migrate your Python Protobuf producers and " + "consumers to 'use.deprecated.format':False as " + "soon as possible") + + descriptor = message_type.DESCRIPTOR + self._msg_class = GetMessageClass(descriptor) + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + @staticmethod + def _decode_varint(buf: io.BytesIO, zigzag: bool = True) -> int: + """ + Decodes a single varint from a buffer. 
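+
+        Seven bits are read per byte, least-significant group first, until a byte
+        with the continuation (high) bit cleared is seen. With ``zigzag`` the
+        unsigned result is then mapped back to a signed integer.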
+ + Args: + buf (BytesIO): buffer to read from + zigzag (bool): decode as zigzag or uvarint + + Returns: + int: decoded varint + + Raises: + EOFError: if buffer is empty + """ + + value = 0 + shift = 0 + try: + while True: + i = ProtobufDeserializer._read_byte(buf) + + value |= (i & 0x7f) << shift + shift += 7 + if not (i & 0x80): + break + + if zigzag: + value = (value >> 1) ^ -(value & 1) + + return value + + except EOFError: + raise EOFError("Unexpected EOF while reading index") + + @staticmethod + def _read_byte(buf: io.BytesIO) -> int: + """ + Read one byte from buf as an int. + + Args: + buf (BytesIO): The buffer to read from. + + .. _ord: + https://docs.python.org/2/library/functions.html#ord + """ + + i = buf.read(1) + if i == b'': + raise EOFError("Unexpected EOF encountered") + return ord(i) + + @staticmethod + def _read_index_array(buf: io.BytesIO, zigzag: bool = True) -> List[int]: + """ + Read an index array from buf that specifies the message + descriptor of interest in the file descriptor. + + Args: + buf (BytesIO): The buffer to read from. + + Returns: + list of int: The index array. + """ + + size = ProtobufDeserializer._decode_varint(buf, zigzag=zigzag) + if size < 0 or size > 100000: + raise DecodeError("Invalid Protobuf msgidx array length") + + if size == 0: + return [0] + + msg_index = [] + for _ in range(size): + msg_index.append(ProtobufDeserializer._decode_varint(buf, + zigzag=zigzag)) + + return msg_index + + def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + return self.__serialize(data, ctx) + + def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + """ + Deserialize a serialized protobuf message with Confluent Schema Registry + framing. + + Args: + data (bytes): Serialized protobuf message with Confluent Schema + Registry framing. + + ctx (SerializationContext): Metadata relevant to the serialization + operation. + + Returns: + Message: Protobuf Message instance. + + Raises: + SerializerError: If there was an error reading the Confluent framing + data, or parsing the protobuf serialized message. + """ + + if data is None: + return None + + # SR wire protocol + msg_index length + if len(data) < 6: + raise SerializationError("Expecting data framing of length 6 bytes or " + "more but total data size is {} bytes. This " + "message was not produced with a Confluent " + "Schema Registry serializer".format(len(data))) + + subject = self._subject_name_func(ctx, None) + latest_schema = None + if subject is not None and self._registry is not None: + latest_schema = self._get_reader_schema(subject, fmt='serialized') + + with _ContextStringIO(data) as payload: + magic, schema_id = struct.unpack('>bI', payload.read(5)) + if magic != _MAGIC_BYTE: + raise SerializationError("Unknown magic byte. 
This message was " + "not produced with a Confluent " + "Schema Registry serializer") + + msg_index = self._read_index_array(payload, zigzag=not self._use_deprecated_format) + + if self._registry is not None: + writer_schema_raw = self._registry.get_schema(schema_id, fmt='serialized') + fd_proto, pool = self._get_parsed_schema(writer_schema_raw) + writer_schema = pool.FindFileByName(fd_proto.name) + writer_desc = self._get_message_desc(pool, writer_schema, msg_index) + if subject is None: + subject = self._subject_name_func(ctx, writer_desc.full_name) + if subject is not None: + latest_schema = self._get_reader_schema(subject, fmt='serialized') + else: + writer_schema_raw = None + writer_schema = None + + if latest_schema is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) + reader_schema_raw = latest_schema.schema + fd_proto, pool = self._get_parsed_schema(latest_schema.schema) + reader_schema = pool.FindFileByName(fd_proto.name) + else: + migrations = None + reader_schema_raw = writer_schema_raw + reader_schema = writer_schema + + if reader_schema is not None: + # Initialize reader desc to first message in file + reader_desc = self._get_message_desc(pool, reader_schema, [0]) + # Attempt to find a reader desc with the same name as the writer + reader_desc = reader_schema.message_types_by_name.get(writer_desc.name, reader_desc) + + if migrations: + msg = GetMessageClass(writer_desc)() + try: + msg.ParseFromString(payload.read()) + except DecodeError as e: + raise SerializationError(str(e)) + + obj_dict = json_format.MessageToDict(msg, True) + obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) + msg = GetMessageClass(reader_desc)() + msg = json_format.ParseDict(obj_dict, msg) + else: + # Protobuf Messages are self-describing; no need to query schema + msg = self._msg_class() + try: + msg.ParseFromString(payload.read()) + except DecodeError as e: + raise SerializationError(str(e)) + + field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + transform(rule_ctx, reader_desc, message, field_transform)) + msg = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, msg, None, + field_transformer) + + return msg + + def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]: + result = self._parsed_schemas.get_parsed_schema(schema) + if result is not None: + return result + + pool = DescriptorPool() + _init_pool(pool) + _resolve_named_schema(schema, self._registry, pool) + fd_proto = _str_to_proto("default", schema.schema_str) + pool.Add(fd_proto) + self._parsed_schemas.set(schema, (fd_proto, pool)) + return fd_proto, pool + + def _get_message_desc( + self, pool: DescriptorPool, fd: FileDescriptor, + msg_index: List[int] + ) -> Descriptor: + file_desc_proto = descriptor_pb2.FileDescriptorProto() + fd.CopyToProto(file_desc_proto) + (full_name, desc_proto) = self._get_message_desc_proto("", file_desc_proto, msg_index) + package = file_desc_proto.package + qualified_name = package + "." + full_name if package else full_name + return pool.FindMessageTypeByName(qualified_name) + + def _get_message_desc_proto( + self, + path: str, + desc: Union[descriptor_pb2.FileDescriptorProto, descriptor_pb2.DescriptorProto], + msg_index: List[int] + ) -> Tuple[str, descriptor_pb2.DescriptorProto]: + index = msg_index[0] + if isinstance(desc, descriptor_pb2.FileDescriptorProto): + msg = desc.message_type[index] + path = path + "." 
+ msg.name if path else msg.name + if len(msg_index) == 1: + return path, msg + return self._get_message_desc_proto(path, msg, msg_index[1:]) + else: + msg = desc.nested_type[index] + path = path + "." + msg.name if path else msg.name + if len(msg_index) == 1: + return path, msg + return self._get_message_desc_proto(path, msg, msg_index[1:]) diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py new file mode 100644 index 000000000..1f86dbf1c --- /dev/null +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -0,0 +1,1115 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import logging +import time +import urllib +from urllib.parse import unquote, urlparse + +import httpx +from typing import List, Dict, Optional, Union, Any, Tuple, Callable + +from cachetools import TTLCache, LRUCache +from httpx import Response + +from authlib.integrations.httpx_client import OAuth2Client + +from confluent_kafka.schema_registry.error import SchemaRegistryError, OAuthTokenError +from confluent_kafka.schema_registry.common.schema_registry_client import ( + RegisteredSchema, + ServerConfig, + is_success, + is_retriable, + _BearerFieldProvider, + full_jitter, + _SchemaCache, + Schema, +) + +# TODO: consider adding `six` dependency or employing a compat file +# Python 2.7 is officially EOL so compatibility issue will be come more the norm. +# We need a better way to handle these issues. +# Six is one possibility but the compat file pattern used by requests +# is also quite nice. 
+# +# six: https://pypi.org/project/six/ +# compat file : https://github.com/psf/requests/blob/master/requests/compat.py +try: + string_type = basestring # noqa + + def _urlencode(value: str) -> str: + return urllib.quote(value, safe='') +except NameError: + string_type = str + + def _urlencode(value: str) -> str: + return urllib.parse.quote(value, safe='') + +log = logging.getLogger(__name__) + + +class _StaticFieldProvider(_BearerFieldProvider): + def __init__(self, token: str, logical_cluster: str, identity_pool: str): + self.token = token + self.logical_cluster = logical_cluster + self.identity_pool = identity_pool + + def get_bearer_fields(self) -> dict: + return {'bearer.auth.token': self.token, 'bearer.auth.logical.cluster': self.logical_cluster, + 'bearer.auth.identity.pool.id': self.identity_pool} + + +class _CustomOAuthClient(_BearerFieldProvider): + def __init__(self, custom_function: Callable[[Dict], Dict], custom_config: dict): + self.custom_function = custom_function + self.custom_config = custom_config + + def get_bearer_fields(self) -> dict: + return self.custom_function(self.custom_config) + + +class _OAuthClient(_BearerFieldProvider): + def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str, logical_cluster: str, + identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int): + self.token = None + self.logical_cluster = logical_cluster + self.identity_pool = identity_pool + self.client = OAuth2Client(client_id=client_id, client_secret=client_secret, scope=scope) + self.token_endpoint = token_endpoint + self.max_retries = max_retries + self.retries_wait_ms = retries_wait_ms + self.retries_max_wait_ms = retries_max_wait_ms + self.token_expiry_threshold = 0.8 + + def get_bearer_fields(self) -> dict: + return { + 'bearer.auth.token': self.get_access_token(), + 'bearer.auth.logical.cluster': self.logical_cluster, + 'bearer.auth.identity.pool.id': self.identity_pool + } + + def token_expired(self) -> bool: + expiry_window = self.token['expires_in'] * self.token_expiry_threshold + + return self.token['expires_at'] < time.time() + expiry_window + + def get_access_token(self) -> str: + if not self.token or self.token_expired(): + self.generate_access_token() + + return self.token['access_token'] + + def generate_access_token(self) -> None: + for i in range(self.max_retries + 1): + try: + self.token = self.client.fetch_token(url=self.token_endpoint, grant_type='client_credentials') + return + except Exception as e: + if i >= self.max_retries: + raise OAuthTokenError(f"Failed to retrieve token after {self.max_retries} " + f"attempts due to error: {str(e)}") + time.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000) + + +class _BaseRestClient(object): + + def __init__(self, conf: dict): + # copy dict to avoid mutating the original + conf_copy = conf.copy() + + base_url = conf_copy.pop('url', None) + if base_url is None: + raise ValueError("Missing required configuration property url") + if not isinstance(base_url, string_type): + raise TypeError("url must be a str, not " + str(type(base_url))) + base_urls = [] + for url in base_url.split(','): + url = url.strip().rstrip('/') + if not url.startswith('http') and not url.startswith('mock'): + raise ValueError("Invalid url {}".format(url)) + base_urls.append(url) + if not base_urls: + raise ValueError("Missing required configuration property url") + self.base_urls = base_urls + + self.verify = True + ca = conf_copy.pop('ssl.ca.location', None) + if ca is 
not None: + self.verify = ca + + key: Optional[str] = conf_copy.pop('ssl.key.location', None) + client_cert: Optional[str] = conf_copy.pop('ssl.certificate.location', None) + self.cert: Union[str, Tuple[str, str], None] = None + + if client_cert is not None and key is not None: + self.cert = (client_cert, key) + + if client_cert is not None and key is None: + self.cert = client_cert + + if key is not None and client_cert is None: + raise ValueError("ssl.certificate.location required when" + " configuring ssl.key.location") + + parsed = urlparse(self.base_urls[0]) + try: + userinfo = (unquote(parsed.username), unquote(parsed.password)) + except (AttributeError, TypeError): + userinfo = ("", "") + if 'basic.auth.user.info' in conf_copy: + if userinfo != ('', ''): + raise ValueError("basic.auth.user.info configured with" + " userinfo credentials in the URL." + " Remove userinfo credentials from the url or" + " remove basic.auth.user.info from the" + " configuration") + + userinfo = tuple(conf_copy.pop('basic.auth.user.info', '').split(':', 1)) + + if len(userinfo) != 2: + raise ValueError("basic.auth.user.info must be in the form" + " of {username}:{password}") + + self.auth = userinfo if userinfo != ('', '') else None + + # The following adds support for proxy config + # If specified: it uses the specified proxy details when making requests + self.proxy = None + proxy = conf_copy.pop('proxy', None) + if proxy is not None: + self.proxy = proxy + + self.timeout = None + timeout = conf_copy.pop('timeout', None) + if timeout is not None: + self.timeout = timeout + + self.cache_capacity = 1000 + cache_capacity = conf_copy.pop('cache.capacity', None) + if cache_capacity is not None: + if not isinstance(cache_capacity, (int, float)): + raise TypeError("cache.capacity must be a number, not " + str(type(cache_capacity))) + self.cache_capacity = cache_capacity + + self.cache_latest_ttl_sec = None + cache_latest_ttl_sec = conf_copy.pop('cache.latest.ttl.sec', None) + if cache_latest_ttl_sec is not None: + if not isinstance(cache_latest_ttl_sec, (int, float)): + raise TypeError("cache.latest.ttl.sec must be a number, not " + str(type(cache_latest_ttl_sec))) + self.cache_latest_ttl_sec = cache_latest_ttl_sec + + self.max_retries = 3 + max_retries = conf_copy.pop('max.retries', None) + if max_retries is not None: + if not isinstance(max_retries, (int, float)): + raise TypeError("max.retries must be a number, not " + str(type(max_retries))) + self.max_retries = max_retries + + self.retries_wait_ms = 1000 + retries_wait_ms = conf_copy.pop('retries.wait.ms', None) + if retries_wait_ms is not None: + if not isinstance(retries_wait_ms, (int, float)): + raise TypeError("retries.wait.ms must be a number, not " + + str(type(retries_wait_ms))) + self.retries_wait_ms = retries_wait_ms + + self.retries_max_wait_ms = 20000 + retries_max_wait_ms = conf_copy.pop('retries.max.wait.ms', None) + if retries_max_wait_ms is not None: + if not isinstance(retries_max_wait_ms, (int, float)): + raise TypeError("retries.max.wait.ms must be a number, not " + + str(type(retries_max_wait_ms))) + self.retries_max_wait_ms = retries_max_wait_ms + + self.bearer_field_provider = None + logical_cluster = None + identity_pool = None + self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None) + if self.bearer_auth_credentials_source is not None: + self.auth = None + + if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}: + headers = ['bearer.auth.logical.cluster', 
'bearer.auth.identity.pool.id'] + missing_headers = [header for header in headers if header not in conf_copy] + if missing_headers: + raise ValueError("Missing required bearer configuration properties: {}" + .format(", ".join(missing_headers))) + + logical_cluster = conf_copy.pop('bearer.auth.logical.cluster') + if not isinstance(logical_cluster, str): + raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster))) + + identity_pool = conf_copy.pop('bearer.auth.identity.pool.id') + if not isinstance(identity_pool, str): + raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) + + if self.bearer_auth_credentials_source == 'OAUTHBEARER': + properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope', + 'bearer.auth.issuer.endpoint.url'] + missing_properties = [prop for prop in properties_list if prop not in conf_copy] + if missing_properties: + raise ValueError("Missing required OAuth configuration properties: {}". + format(", ".join(missing_properties))) + + self.client_id = conf_copy.pop('bearer.auth.client.id') + if not isinstance(self.client_id, string_type): + raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id))) + + self.client_secret = conf_copy.pop('bearer.auth.client.secret') + if not isinstance(self.client_secret, string_type): + raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret))) + + self.scope = conf_copy.pop('bearer.auth.scope') + if not isinstance(self.scope, string_type): + raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope))) + + self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url') + if not isinstance(self.token_endpoint, string_type): + raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " + + str(type(self.token_endpoint))) + + self.bearer_field_provider = _OAuthClient(self.client_id, self.client_secret, self.scope, + self.token_endpoint, logical_cluster, identity_pool, + self.max_retries, self.retries_wait_ms, + self.retries_max_wait_ms) + elif self.bearer_auth_credentials_source == 'STATIC_TOKEN': + if 'bearer.auth.token' not in conf_copy: + raise ValueError("Missing bearer.auth.token") + static_token = conf_copy.pop('bearer.auth.token') + self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) + if not isinstance(static_token, string_type): + raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) + elif self.bearer_auth_credentials_source == 'CUSTOM': + custom_bearer_properties = ['bearer.auth.custom.provider.function', + 'bearer.auth.custom.provider.config'] + missing_custom_properties = [prop for prop in custom_bearer_properties if prop not in conf_copy] + if missing_custom_properties: + raise ValueError("Missing required custom OAuth configuration properties: {}". 
+ format(", ".join(missing_custom_properties))) + + custom_function = conf_copy.pop('bearer.auth.custom.provider.function') + if not callable(custom_function): + raise TypeError("bearer.auth.custom.provider.function must be a callable, not " + + str(type(custom_function))) + + custom_config = conf_copy.pop('bearer.auth.custom.provider.config') + if not isinstance(custom_config, dict): + raise TypeError("bearer.auth.custom.provider.config must be a dict, not " + + str(type(custom_config))) + + self.bearer_field_provider = _CustomOAuthClient(custom_function, custom_config) + else: + raise ValueError('Unrecognized bearer.auth.credentials.source') + + # Any leftover keys are unknown to _RestClient + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + def get(self, url: str, query: Optional[dict] = None) -> Any: + raise NotImplementedError() + + def post(self, url: str, body: Optional[dict], **kwargs) -> Any: + raise NotImplementedError() + + def delete(self, url: str) -> Any: + raise NotImplementedError() + + def put(self, url: str, body: Optional[dict] = None) -> Any: + raise NotImplementedError() + + +class _RestClient(_BaseRestClient): + """ + HTTP client for Confluent Schema Registry. + + See SchemaRegistryClient for configuration details. + + Args: + conf (dict): Dictionary containing _RestClient configuration + """ + + def __init__(self, conf: dict): + super().__init__(conf) + + self.session = httpx.Client( + verify=self.verify, + cert=self.cert, + auth=self.auth, + proxy=self.proxy, + timeout=self.timeout + ) + + def handle_bearer_auth(self, headers: dict) -> None: + bearer_fields = self.bearer_field_provider.get_bearer_fields() + required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster'] + + missing_fields = [] + for field in required_fields: + if field not in bearer_fields: + missing_fields.append(field) + + if missing_fields: + raise ValueError("Missing required bearer auth fields, needs to be set in config or custom function: {}" + .format(", ".join(missing_fields))) + + headers["Authorization"] = "Bearer {}".format(bearer_fields['bearer.auth.token']) + headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] + headers['target-sr-cluster'] = bearer_fields['bearer.auth.logical.cluster'] + + def get(self, url: str, query: Optional[dict] = None) -> Any: + return self.send_request(url, method='GET', query=query) + + def post(self, url: str, body: Optional[dict], **kwargs) -> Any: + return self.send_request(url, method='POST', body=body) + + def delete(self, url: str) -> Any: + return self.send_request(url, method='DELETE') + + def put(self, url: str, body: Optional[dict] = None) -> Any: + return self.send_request(url, method='PUT', body=body) + + def send_request( + self, url: str, method: str, body: Optional[dict] = None, + query: Optional[dict] = None + ) -> Any: + """ + Sends HTTP request to the SchemaRegistry, trying each base URL in turn. + + All unsuccessful attempts will raise a SchemaRegistryError with the + response contents. In most cases this will be accompanied by a + Schema Registry supplied error code. + + In the event the response is malformed an error_code of -1 will be used. + + Args: + url (str): Request path + + method (str): HTTP method + + body (str): Request content + + query (dict): Query params to attach to the URL + + Returns: + dict: Schema Registry response content. 
+ """ + + headers = {'Accept': "application/vnd.schemaregistry.v1+json," + " application/vnd.schemaregistry+json," + " application/json"} + + if body is not None: + body = json.dumps(body) + headers = {'Content-Length': str(len(body)), + 'Content-Type': "application/vnd.schemaregistry.v1+json"} + + if self.bearer_auth_credentials_source: + self.handle_bearer_auth(headers) + + response = None + for i, base_url in enumerate(self.base_urls): + try: + response = self.send_http_request( + base_url, url, method, headers, body, query) + + if is_success(response.status_code): + return response.json() + + if not is_retriable(response.status_code) or i == len(self.base_urls) - 1: + break + except Exception as e: + if i == len(self.base_urls) - 1: + # Raise the exception since we have no more urls to try + raise e + + try: + raise SchemaRegistryError(response.status_code, + response.json().get('error_code'), + response.json().get('message')) + # Schema Registry may return malformed output when it hits unexpected errors + except (ValueError, KeyError, AttributeError): + raise SchemaRegistryError(response.status_code, + -1, + "Unknown Schema Registry Error: " + + str(response.content)) + + def send_http_request( + self, base_url: str, url: str, method: str, headers: Optional[dict], + body: Optional[str] = None, query: Optional[dict] = None + ) -> Response: + """ + Sends HTTP request to the SchemaRegistry. + + All unsuccessful attempts will raise a SchemaRegistryError with the + response contents. In most cases this will be accompanied by a + Schema Registry supplied error code. + + In the event the response is malformed an error_code of -1 will be used. + + Args: + base_url (str): Schema Registry base URL + + url (str): Request path + + method (str): HTTP method + + headers (dict): Headers + + body (str): Request content + + query (dict): Query params to attach to the URL + + Returns: + Response: Schema Registry response content. + """ + response = None + for i in range(self.max_retries + 1): + response = self.session.request( + method, url="/".join([base_url, url]), + headers=headers, content=body, params=query) + + if is_success(response.status_code): + return response + + if not is_retriable(response.status_code) or i >= self.max_retries: + return response + + time.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000) + return response + + +class SchemaRegistryClient(object): + """ + A Confluent Schema Registry client. + + Configuration properties (* indicates a required field): + + +------------------------------+------+-------------------------------------------------+ + | Property name | type | Description | + +==============================+======+=================================================+ + | ``url`` * | str | Comma-separated list of Schema Registry URLs. | + +------------------------------+------+-------------------------------------------------+ + | | | Path to CA certificate file used | + | ``ssl.ca.location`` | str | to verify the Schema Registry's | + | | | private key. | + +------------------------------+------+-------------------------------------------------+ + | | | Path to client's private key | + | | | (PEM) used for authentication. | + | ``ssl.key.location`` | str | | + | | | ``ssl.certificate.location`` must also be set. | + +------------------------------+------+-------------------------------------------------+ + | | | Path to client's public key (PEM) used for | + | | | authentication. 
| + | ``ssl.certificate.location`` | str | | + | | | May be set without ssl.key.location if the | + | | | private key is stored within the PEM as well. | + +------------------------------+------+-------------------------------------------------+ + | | | Client HTTP credentials in the form of | + | | | ``username:password``. | + | ``basic.auth.user.info`` | str | | + | | | By default userinfo is extracted from | + | | | the URL if present. | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``proxy`` | str | Proxy such as http://localhost:8030. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``timeout`` | int | Request timeout. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``cache.capacity`` | int | Cache capacity. Defaults to 1000. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``cache.latest.ttl.sec`` | int | TTL in seconds for caching the latest schema. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``max.retries`` | int | Maximum retries for a request. Defaults to 2. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | Maximum time to wait for the first retry. | + | | | When jitter is applied, the actual wait may | + | ``retries.wait.ms`` | int | be less. | + | | | | + | | | Defaults to 1000. | + +------------------------------+------+-------------------------------------------------+ + + Args: + conf (dict): Schema Registry client configuration. + + See Also: + `Confluent Schema Registry documentation `_ + """ # noqa: E501 + + def __init__(self, conf: dict): + self._conf = conf + self._rest_client = _RestClient(conf) + self._cache = _SchemaCache() + cache_capacity = self._rest_client.cache_capacity + cache_ttl = self._rest_client.cache_latest_ttl_sec + if cache_ttl is not None: + self._latest_version_cache = TTLCache(cache_capacity, cache_ttl) + self._latest_with_metadata_cache = TTLCache(cache_capacity, cache_ttl) + else: + self._latest_version_cache = LRUCache(cache_capacity) + self._latest_with_metadata_cache = LRUCache(cache_capacity) + + def __enter__(self): + return self + + def __exit__(self, *args): + if self._rest_client is not None: + self._rest_client.session.close() + + def config(self): + return self._conf + + def register_schema( + self, subject_name: str, schema: 'Schema', + normalize_schemas: bool = False + ) -> int: + """ + Registers a schema under ``subject_name``. + + Args: + subject_name (str): subject to register a schema under + schema (Schema): Schema instance to register + normalize_schemas (bool): Normalize schema before registering + + Returns: + int: Schema id + + Raises: + SchemaRegistryError: if Schema violates this subject's + Compatibility policy or is otherwise invalid. + + See Also: + `POST Subject API Reference `_ + """ # noqa: E501 + + registered_schema = self.register_schema_full_response(subject_name, schema, normalize_schemas) + return registered_schema.schema_id + + def register_schema_full_response( + self, subject_name: str, schema: 'Schema', + normalize_schemas: bool = False + ) -> 'RegisteredSchema': + """ + Registers a schema under ``subject_name``. 
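+
+        If the schema was already registered under this subject in a previous call,
+        the cached schema id is reused and no additional request is made to Schema Registry.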
+ + Args: + subject_name (str): subject to register a schema under + schema (Schema): Schema instance to register + normalize_schemas (bool): Normalize schema before registering + + Returns: + int: Schema id + + Raises: + SchemaRegistryError: if Schema violates this subject's + Compatibility policy or is otherwise invalid. + + See Also: + `POST Subject API Reference `_ + """ # noqa: E501 + + schema_id = self._cache.get_id_by_schema(subject_name, schema) + if schema_id is not None: + return RegisteredSchema(schema_id, schema, subject_name, None) + + request = schema.to_dict() + + response = self._rest_client.post( + 'subjects/{}/versions?normalize={}'.format(_urlencode(subject_name), normalize_schemas), + body=request) + + registered_schema = RegisteredSchema.from_dict(response) + + # The registered schema may not be fully populated + self._cache.set_schema(subject_name, registered_schema.schema_id, schema) + + return registered_schema + + def get_schema( + self, schema_id: int, subject_name: Optional[str] = None, fmt: Optional[str] = None + ) -> 'Schema': + """ + Fetches the schema associated with ``schema_id`` from the + Schema Registry. The result is cached so subsequent attempts will not + require an additional round-trip to the Schema Registry. + + Args: + schema_id (int): Schema id + subject_name (str): Subject name the schema is registered under + fmt (str): Format of the schema + + Returns: + Schema: Schema instance identified by the ``schema_id`` + + Raises: + SchemaRegistryError: If schema can't be found. + + See Also: + `GET Schema API Reference `_ + """ # noqa: E501 + + schema = self._cache.get_schema_by_id(subject_name, schema_id) + if schema is not None: + return schema + + query = {'subject': subject_name} if subject_name is not None else None + if fmt is not None: + if query is not None: + query['format'] = fmt + else: + query = {'format': fmt} + response = self._rest_client.get('schemas/ids/{}'.format(schema_id), query) + + schema = Schema.from_dict(response) + + self._cache.set_schema(subject_name, schema_id, schema) + + return schema + + def lookup_schema( + self, subject_name: str, schema: 'Schema', + normalize_schemas: bool = False, deleted: bool = False + ) -> 'RegisteredSchema': + """ + Returns ``schema`` registration information for ``subject``. + + Args: + subject_name (str): Subject name the schema is registered under + schema (Schema): Schema instance. + normalize_schemas (bool): Normalize schema before registering + deleted (bool): Whether to include deleted schemas. + + Returns: + RegisteredSchema: Subject registration information for this schema. 
+ + Raises: + SchemaRegistryError: If schema or subject can't be found + + See Also: + `POST Subject API Reference `_ + """ # noqa: E501 + + registered_schema = self._cache.get_registered_by_subject_schema(subject_name, schema) + if registered_schema is not None: + return registered_schema + + request = schema.to_dict() + + response = self._rest_client.post('subjects/{}?normalize={}&deleted={}' + .format(_urlencode(subject_name), normalize_schemas, deleted), + body=request) + + result = RegisteredSchema.from_dict(response) + + # Ensure the schema matches the input + registered_schema = RegisteredSchema( + schema_id=result.schema_id, + subject=result.subject, + version=result.version, + schema=schema, + ) + + self._cache.set_registered_schema(schema, registered_schema) + + return registered_schema + + def get_subjects(self) -> List[str]: + """ + List all subjects registered with the Schema Registry + + Returns: + list(str): Registered subject names + + Raises: + SchemaRegistryError: if subjects can't be found + + See Also: + `GET subjects API Reference `_ + """ # noqa: E501 + + return self._rest_client.get('subjects') + + def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int]: + """ + Deletes the specified subject and its associated compatibility level if + registered. It is recommended to use this API only when a topic needs + to be recycled or in development environments. + + Args: + subject_name (str): subject name + permanent (bool): True for a hard delete, False (default) for a soft delete + + Returns: + list(int): Versions deleted under this subject + + Raises: + SchemaRegistryError: if the request was unsuccessful. + + See Also: + `DELETE Subject API Reference `_ + """ # noqa: E501 + + if permanent: + versions = self._rest_client.delete('subjects/{}?permanent=true' + .format(_urlencode(subject_name))) + self._cache.remove_by_subject(subject_name) + else: + versions = self._rest_client.delete('subjects/{}' + .format(_urlencode(subject_name))) + + return versions + + def get_latest_version( + self, subject_name: str, fmt: Optional[str] = None + ) -> 'RegisteredSchema': + """ + Retrieves latest registered version for subject + + Args: + subject_name (str): Subject name. + fmt (str): Format of the schema + + Returns: + RegisteredSchema: Registration information for this version. + + Raises: + SchemaRegistryError: if the version can't be found or is invalid. + + See Also: + `GET Subject Version API Reference `_ + """ # noqa: E501 + + registered_schema = self._latest_version_cache.get(subject_name, None) + if registered_schema is not None: + return registered_schema + + query = {'format': fmt} if fmt is not None else None + response = self._rest_client.get('subjects/{}/versions/{}' + .format(_urlencode(subject_name), + 'latest'), query) + + registered_schema = RegisteredSchema.from_dict(response) + + self._latest_version_cache[subject_name] = registered_schema + + return registered_schema + + def get_latest_with_metadata( + self, subject_name: str, metadata: Dict[str, str], + deleted: bool = False, fmt: Optional[str] = None + ) -> 'RegisteredSchema': + """ + Retrieves latest registered version for subject with the given metadata + + Args: + subject_name (str): Subject name. + metadata (dict): The key-value pairs for the metadata. + deleted (bool): Whether to include deleted schemas. + fmt (str): Format of the schema + + Returns: + RegisteredSchema: Registration information for this version. 
+ + Raises: + SchemaRegistryError: if the version can't be found or is invalid. + """ # noqa: E501 + + cache_key = (subject_name, frozenset(metadata.items()), deleted) + registered_schema = self._latest_with_metadata_cache.get(cache_key, None) + if registered_schema is not None: + return registered_schema + + query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} + keys = metadata.keys() + if keys: + query['key'] = [_urlencode(key) for key in keys] + query['value'] = [_urlencode(metadata[key]) for key in keys] + + response = self._rest_client.get('subjects/{}/metadata' + .format(_urlencode(subject_name)), query) + + registered_schema = RegisteredSchema.from_dict(response) + + self._latest_with_metadata_cache[cache_key] = registered_schema + + return registered_schema + + def get_version( + self, subject_name: str, version: int, + deleted: bool = False, fmt: Optional[str] = None + ) -> 'RegisteredSchema': + """ + Retrieves a specific schema registered under ``subject_name``. + + Args: + subject_name (str): Subject name. + version (int): version number. Defaults to latest version. + deleted (bool): Whether to include deleted schemas. + fmt (str): Format of the schema + + Returns: + RegisteredSchema: Registration information for this version. + + Raises: + SchemaRegistryError: if the version can't be found or is invalid. + + See Also: + `GET Subject Version API Reference `_ + """ # noqa: E501 + + registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) + if registered_schema is not None: + return registered_schema + + query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} + response = self._rest_client.get('subjects/{}/versions/{}' + .format(_urlencode(subject_name), + version), query) + + registered_schema = RegisteredSchema.from_dict(response) + + self._cache.set_registered_schema(registered_schema.schema, registered_schema) + + return registered_schema + + def get_versions(self, subject_name: str) -> List[int]: + """ + Get a list of all versions registered with this subject. + + Args: + subject_name (str): Subject name. + + Returns: + list(int): Registered versions + + Raises: + SchemaRegistryError: If subject can't be found + + See Also: + `GET Subject Versions API Reference `_ + """ # noqa: E501 + + return self._rest_client.get('subjects/{}/versions'.format(_urlencode(subject_name))) + + def delete_version(self, subject_name: str, version: int, permanent: bool = False) -> int: + """ + Deletes a specific version registered to ``subject_name``. + + Args: + subject_name (str) Subject name + + version (int): Version number + + permanent (bool): True for a hard delete, False (default) for a soft delete + + Returns: + int: Version number which was deleted + + Raises: + SchemaRegistryError: if the subject or version cannot be found. + + See Also: + `Delete Subject Version API Reference `_ + """ # noqa: E501 + + if permanent: + response = self._rest_client.delete('subjects/{}/versions/{}?permanent=true' + .format(_urlencode(subject_name), + version)) + self._cache.remove_by_subject_version(subject_name, version) + else: + response = self._rest_client.delete('subjects/{}/versions/{}' + .format(_urlencode(subject_name), + version)) + + return response + + def set_compatibility(self, subject_name: Optional[str] = None, level: Optional[str] = None) -> str: + """ + Update global or subject level compatibility level. + + Args: + level (str): Compatibility level. 
See API reference for a list of + valid values. + + subject_name (str, optional): Subject to update. Sets global compatibility + level policy if not set. + + Returns: + str: The newly configured compatibility level. + + Raises: + SchemaRegistryError: If the compatibility level is invalid. + + See Also: + `PUT Subject Compatibility API Reference `_ + """ # noqa: E501 + + if level is None: + raise ValueError("level must be set") + + if subject_name is None: + return self._rest_client.put('config', + body={'compatibility': level.upper()}) + + return self._rest_client.put('config/{}' + .format(_urlencode(subject_name)), + body={'compatibility': level.upper()}) + + def get_compatibility(self, subject_name: Optional[str] = None) -> str: + """ + Get the current compatibility level. + + Args: + subject_name (str, optional): Subject name. Returns global policy + if left unset. + + Returns: + str: Compatibility level for the subject if set, otherwise the global compatibility level. + + Raises: + SchemaRegistryError: if the request was unsuccessful. + + See Also: + `GET Subject Compatibility API Reference `_ + """ # noqa: E501 + + if subject_name is not None: + url = 'config/{}'.format(_urlencode(subject_name)) + else: + url = 'config' + + result = self._rest_client.get(url) + return result['compatibilityLevel'] + + def test_compatibility( + self, subject_name: str, schema: 'Schema', + version: Union[int, str] = "latest" + ) -> bool: + """Test the compatibility of a candidate schema for a given subject and version + + Args: + subject_name (str): Subject name the schema is registered under + schema (Schema): Schema instance. + version (int or str, optional): Version number, or the string "latest". Defaults to "latest". + + Returns: + bool: True if the schema is compatible with the specified version + + Raises: + SchemaRegistryError: if the request was unsuccessful. + + See Also: + `POST Test Compatibility API Reference `_ + """ # noqa: E501 + + request = schema.to_dict() + + response = self._rest_client.post( + 'compatibility/subjects/{}/versions/{}'.format(_urlencode(subject_name), version), body=request + ) + + return response['is_compatible'] + + def set_config( + self, subject_name: Optional[str] = None, + config: Optional['ServerConfig'] = None + ) -> 'ServerConfig': + """ + Update global or subject config. + + Args: + config (ServerConfig): Config. See API reference for a list of + valid values. + + subject_name (str, optional): Subject to update. Sets global config + if not set. + + Returns: + str: The newly configured config. + + Raises: + SchemaRegistryError: If the config is invalid. + + See Also: + `PUT Subject Config API Reference `_ + """ # noqa: E501 + + if config is None: + raise ValueError("config must be set") + + if subject_name is None: + return self._rest_client.put('config', + body=config.to_dict()) + + return self._rest_client.put('config/{}' + .format(_urlencode(subject_name)), + body=config.to_dict()) + + def get_config(self, subject_name: Optional[str] = None) -> 'ServerConfig': + """ + Get the current config. + + Args: + subject_name (str, optional): Subject name. Returns global config + if left unset. + + Returns: + ServerConfig: Config for the subject if set, otherwise the global config. + + Raises: + SchemaRegistryError: if the request was unsuccessful. 
+ + See Also: + `GET Subject Config API Reference `_ + """ # noqa: E501 + + if subject_name is not None: + url = 'config/{}'.format(_urlencode(subject_name)) + else: + url = 'config' + + result = self._rest_client.get(url) + return ServerConfig.from_dict(result) + + def clear_latest_caches(self): + self._latest_version_cache.clear() + self._latest_with_metadata_cache.clear() + + def clear_caches(self): + self._latest_version_cache.clear() + self._latest_with_metadata_cache.clear() + self._cache.clear() + + @staticmethod + def new_client(conf: dict) -> 'SchemaRegistryClient': + from confluent_kafka.schema_registry.mock_schema_registry_client import MockSchemaRegistryClient + url = conf.get("url") + if url.startswith("mock://"): + return MockSchemaRegistryClient(conf) + return SchemaRegistryClient(conf) diff --git a/src/confluent_kafka/schema_registry/_sync/serde.py b/src/confluent_kafka/schema_registry/_sync/serde.py new file mode 100644 index 000000000..409e00160 --- /dev/null +++ b/src/confluent_kafka/schema_registry/_sync/serde.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2024 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +from typing import List, Optional, Set, Dict, Any + +from confluent_kafka.schema_registry import RegisteredSchema +from confluent_kafka.schema_registry.common.serde import ErrorAction, FieldTransformer, Migration, NoneAction, RuleAction, RuleConditionError, RuleContext, RuleError +from confluent_kafka.schema_registry.schema_registry_client import RuleMode, \ + Rule, RuleKind, Schema, RuleSet +from confluent_kafka.serialization import Serializer, Deserializer, \ + SerializationContext, SerializationError + +log = logging.getLogger(__name__) + +class BaseSerde(object): + __slots__ = ['_use_schema_id', '_use_latest_version', '_use_latest_with_metadata', + '_registry', '_rule_registry', '_subject_name_func', + '_field_transformer'] + + def _get_reader_schema(self, subject: str, fmt: Optional[str] = None) -> Optional[RegisteredSchema]: + if self._use_schema_id is not None: + schema = self._registry.get_schema(self._use_schema_id, subject, fmt) + return self._registry.lookup_schema(subject, schema, False, True) + if self._use_latest_with_metadata is not None: + return self._registry.get_latest_with_metadata( + subject, self._use_latest_with_metadata, True, fmt) + if self._use_latest_version: + return self._registry.get_latest_version(subject, fmt) + return None + + def _execute_rules( + self, ser_ctx: SerializationContext, subject: str, + rule_mode: RuleMode, + source: Optional[Schema], target: Optional[Schema], + message: Any, inline_tags: Optional[Dict[str, Set[str]]], + field_transformer: Optional[FieldTransformer] + ) -> Any: + if message is None or target is None: + return message + rules: Optional[List[Rule]] = None + if rule_mode == RuleMode.UPGRADE: + if target is not None and target.rule_set is not None: + rules = target.rule_set.migration_rules + elif rule_mode == RuleMode.DOWNGRADE: + if 
source is not None and source.rule_set is not None: + rules = source.rule_set.migration_rules + rules = rules[:] if rules else [] + rules.reverse() + else: + if target is not None and target.rule_set is not None: + rules = target.rule_set.domain_rules + if rule_mode == RuleMode.READ: + # Execute read rules in reverse order for symmetry + rules = rules[:] if rules else [] + rules.reverse() + + if not rules: + return message + + for index in range(len(rules)): + rule = rules[index] + if self._is_disabled(rule): + continue + if rule.mode == RuleMode.WRITEREAD: + if rule_mode != RuleMode.READ and rule_mode != RuleMode.WRITE: + continue + elif rule.mode == RuleMode.UPDOWN: + if rule_mode != RuleMode.UPGRADE and rule_mode != RuleMode.DOWNGRADE: + continue + elif rule.mode != rule_mode: + continue + + ctx = RuleContext(ser_ctx, source, target, subject, rule_mode, rule, + index, rules, inline_tags, field_transformer) + rule_executor = self._rule_registry.get_executor(rule.type.upper()) + if rule_executor is None: + self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), message, + RuleError(f"Could not find rule executor of type {rule.type}"), + 'ERROR') + return message + try: + result = rule_executor.transform(ctx, message) + if rule.kind == RuleKind.CONDITION: + if not result: + raise RuleConditionError(rule) + elif rule.kind == RuleKind.TRANSFORM: + message = result + self._run_action( + ctx, rule_mode, rule, + self._get_on_failure(rule) if message is None else self._get_on_success(rule), + message, None, + 'ERROR' if message is None else 'NONE') + except SerializationError: + raise + except Exception as e: + self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), + message, e, 'ERROR') + return message + + def _get_on_success(self, rule: Rule) -> Optional[str]: + override = self._rule_registry.get_override(rule.type) + if override is not None and override.on_success is not None: + return override.on_success + return rule.on_success + + def _get_on_failure(self, rule: Rule) -> Optional[str]: + override = self._rule_registry.get_override(rule.type) + if override is not None and override.on_failure is not None: + return override.on_failure + return rule.on_failure + + def _is_disabled(self, rule: Rule) -> Optional[bool]: + override = self._rule_registry.get_override(rule.type) + if override is not None and override.disabled is not None: + return override.disabled + return rule.disabled + + def _run_action( + self, ctx: RuleContext, rule_mode: RuleMode, rule: Rule, + action: Optional[str], message: Any, + ex: Optional[Exception], default_action: str + ): + action_name = self._get_rule_action_name(rule, rule_mode, action) + if action_name is None: + action_name = default_action + rule_action = self._get_rule_action(ctx, action_name) + if rule_action is None: + log.error("Could not find rule action of type %s", action_name) + raise RuleError(f"Could not find rule action of type {action_name}") + try: + rule_action.run(ctx, message, ex) + except SerializationError: + raise + except Exception as e: + log.warning("Could not run post-rule action %s: %s", action_name, e) + + def _get_rule_action_name( + self, rule: Rule, rule_mode: RuleMode, action_name: Optional[str] + ) -> Optional[str]: + if action_name is None or action_name == "": + return None + if rule.mode in (RuleMode.WRITEREAD, RuleMode.UPDOWN) and ',' in action_name: + parts = action_name.split(',') + if rule_mode in (RuleMode.WRITE, RuleMode.UPGRADE): + return parts[0] + elif rule_mode in (RuleMode.READ, 
RuleMode.DOWNGRADE): + return parts[1] + return action_name + + def _get_rule_action(self, ctx: RuleContext, action_name: str) -> Optional[RuleAction]: + if action_name == 'ERROR': + return ErrorAction() + elif action_name == 'NONE': + return NoneAction() + return self._rule_registry.get_action(action_name) + + +class BaseSerializer(BaseSerde, Serializer): + __slots__ = ['_auto_register', '_normalize_schemas'] + + +class BaseDeserializer(BaseSerde, Deserializer): + __slots__ = [] + + def _has_rules(self, rule_set: RuleSet, mode: RuleMode) -> bool: + if rule_set is None: + return False + if mode in (RuleMode.UPGRADE, RuleMode.DOWNGRADE): + return any(rule.mode == mode or rule.mode == RuleMode.UPDOWN + for rule in rule_set.migration_rules or []) + elif mode == RuleMode.UPDOWN: + return any(rule.mode == mode for rule in rule_set.migration_rules or []) + elif mode in (RuleMode.WRITE, RuleMode.READ): + return any(rule.mode == mode or rule.mode == RuleMode.WRITEREAD + for rule in rule_set.domain_rules or []) + elif mode == RuleMode.WRITEREAD: + return any(rule.mode == mode for rule in rule_set.migration_rules or []) + return False + + def _get_migrations( + self, subject: str, source_info: Schema, + target: RegisteredSchema, fmt: Optional[str] + ) -> List[Migration]: + source = self._registry.lookup_schema(subject, source_info, False, True) + migrations = [] + if source.version < target.version: + migration_mode = RuleMode.UPGRADE + first = source + last = target + elif source.version > target.version: + migration_mode = RuleMode.DOWNGRADE + first = target + last = source + else: + return migrations + previous: Optional[RegisteredSchema] = None + versions = self._get_schemas_between(subject, first, last, fmt) + for i in range(len(versions)): + version = versions[i] + if i == 0: + previous = version + continue + if version.schema.rule_set is not None and self._has_rules(version.schema.rule_set, migration_mode): + if migration_mode == RuleMode.UPGRADE: + migration = Migration(migration_mode, previous, version) + else: + migration = Migration(migration_mode, version, previous) + migrations.append(migration) + previous = version + if migration_mode == RuleMode.DOWNGRADE: + migrations.reverse() + return migrations + + def _get_schemas_between( + self, subject: str, first: RegisteredSchema, + last: RegisteredSchema, fmt: Optional[str] = None + ) -> List[RegisteredSchema]: + if last.version - first.version <= 1: + return [first, last] + version1 = first.version + version2 = last.version + result = [first] + for i in range(version1 + 1, version2): + result.append(self._registry.get_version(subject, i, True, fmt)) + result.append(last) + return result + + def _execute_migrations( + self, ser_ctx: SerializationContext, subject: str, + migrations: List[Migration], message: Any + ) -> Any: + for migration in migrations: + message = self._execute_rules(ser_ctx, subject, migration.rule_mode, + migration.source.schema, migration.target.schema, + message, None, None) + return message diff --git a/src/confluent_kafka/schema_registry/avro.py b/src/confluent_kafka/schema_registry/avro.py index 8f6cddf8b..103acd58c 100644 --- a/src/confluent_kafka/schema_registry/avro.py +++ b/src/confluent_kafka/schema_registry/avro.py @@ -15,796 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
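# A minimal sketch of the migration walk performed by _get_migrations above,
# with plain ints standing in for RegisteredSchema versions and a hypothetical
# layout in which only versions 3 and 4 declare migration rules (UPGRADE case).
versions = [1, 2, 3, 4]           # what _get_schemas_between would return
has_rules = {3: True, 4: True}    # hypothetical rule_set presence per version
migrations = []
previous = versions[0]
for version in versions[1:]:
    if has_rules.get(version):
        migrations.append((previous, version))   # UPGRADE: (source, target)
    previous = version
print(migrations)   # [(2, 3), (3, 4)] -- applied in order by _execute_migrations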
-import decimal -import re -from collections import defaultdict -from copy import deepcopy -from io import BytesIO -from json import loads -from struct import pack, unpack -from typing import Dict, Union, Optional, Set, Callable - -from fastavro import (schemaless_reader, - schemaless_writer, - repository, - validate) -from fastavro.schema import load_schema - -from . import (_MAGIC_BYTE, - Schema, - topic_subject_name_strategy, - RuleMode, - RuleKind, SchemaRegistryClient) -from confluent_kafka.serialization import (SerializationError, - SerializationContext) -from .rule_registry import RuleRegistry -from .serde import BaseSerializer, BaseDeserializer, RuleContext, FieldType, \ - FieldTransform, RuleConditionError, ParsedSchemaCache - - -AvroMessage = Union[ - None, # 'null' Avro type - str, # 'string' and 'enum' - float, # 'float' and 'double' - int, # 'int' and 'long' - decimal.Decimal, # 'fixed' - bool, # 'boolean' - bytes, # 'bytes' - list, # 'array' - dict, # 'map' and 'record' -] -AvroSchema = Union[str, list, dict] - - -class _ContextStringIO(BytesIO): - """ - Wrapper to allow use of StringIO via 'with' constructs. - """ - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() - return False - - -def _schema_loads(schema_str: str) -> Schema: - """ - Instantiate a Schema instance from a declaration string. - - Args: - schema_str (str): Avro Schema declaration. - - .. _Schema declaration: - https://avro.apache.org/docs/current/spec.html#schemas - - Returns: - Schema: A Schema instance. - """ - - schema_str = schema_str.strip() - - # canonical form primitive declarations are not supported - if schema_str[0] != "{" and schema_str[0] != "[": - schema_str = '{"type":' + schema_str + '}' - - return Schema(schema_str, schema_type='AVRO') - - -def _resolve_named_schema( - schema: Schema, schema_registry_client: SchemaRegistryClient -) -> Dict[str, AvroSchema]: - """ - Resolves named schemas referenced by the provided schema recursively. - :param schema: Schema to resolve named schemas for. - :param schema_registry_client: SchemaRegistryClient to use for retrieval. - :return: named_schemas dict. - """ - named_schemas = {} - if schema.references is not None: - for ref in schema.references: - referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) - ref_named_schemas = _resolve_named_schema(referenced_schema.schema, schema_registry_client) - parsed_schema = parse_schema_with_repo( - referenced_schema.schema.schema_str, named_schemas=ref_named_schemas) - named_schemas.update(ref_named_schemas) - named_schemas[ref.name] = parsed_schema - return named_schemas - - -class AvroSerializer(BaseSerializer): - """ - Serializer that outputs Avro binary encoded data with Confluent Schema Registry framing. - - Configuration properties: - - +-----------------------------+----------+--------------------------------------------------+ - | Property Name | Type | Description | - +=============================+==========+==================================================+ - | | | If True, automatically register the configured | - | ``auto.register.schemas`` | bool | schema with Confluent Schema Registry if it has | - | | | not previously been associated with the relevant | - | | | subject (determined via subject.name.strategy). | - | | | | - | | | Defaults to True. 
| - +-----------------------------+----------+--------------------------------------------------+ - | | | Whether to normalize schemas, which will | - | ``normalize.schemas`` | bool | transform schemas to have a consistent format, | - | | | including ordering properties and references. | - +-----------------------------+----------+--------------------------------------------------+ - | | | Whether to use the given schema ID for | - | ``use.schema.id`` | int | serialization. | - | | | | - +-----------------------------+----------+--------------------------------------------------+ - | | | Whether to use the latest subject version for | - | ``use.latest.version`` | bool | serialization. | - | | | | - | | | WARNING: There is no check that the latest | - | | | schema is backwards compatible with the object | - | | | being serialized. | - | | | | - | | | Defaults to False. | - +-----------------------------+----------+--------------------------------------------------+ - | | | Whether to use the latest subject version with | - | ``use.latest.with.metadata``| dict | the given metadata. | - | | | | - | | | WARNING: There is no check that the latest | - | | | schema is backwards compatible with the object | - | | | being serialized. | - | | | | - | | | Defaults to None. | - +-----------------------------+----------+--------------------------------------------------+ - | | | Callable(SerializationContext, str) -> str | - | | | | - | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | - | | | constructed. Standard naming strategies are | - | | | defined in the confluent_kafka.schema_registry | - | | | namespace. | - | | | | - | | | Defaults to topic_subject_name_strategy. | - +-----------------------------+----------+--------------------------------------------------+ - - Schemas are registered against subject names in Confluent Schema Registry that - define a scope in which the schemas can be evolved. By default, the subject name - is formed by concatenating the topic name with the message field (key or value) - separated by a hyphen. - - i.e. {topic name}-{message field} - - Alternative naming strategies may be configured with the property - ``subject.name.strategy``. - - Supported subject name strategies: - - +--------------------------------------+------------------------------+ - | Subject Name Strategy | Output Format | - +======================================+==============================+ - | topic_subject_name_strategy(default) | {topic name}-{message field} | - +--------------------------------------+------------------------------+ - | topic_record_subject_name_strategy | {topic name}-{record name} | - +--------------------------------------+------------------------------+ - | record_subject_name_strategy | {record name} | - +--------------------------------------+------------------------------+ - - See `Subject name strategy `_ for additional details. - - Note: - Prior to serialization, all values must first be converted to - a dict instance. This may handled manually prior to calling - :py:func:`Producer.produce()` or by registering a `to_dict` - callable with AvroSerializer. - - See ``avro_producer.py`` in the examples directory for example usage. - - Note: - Tuple notation can be used to determine which branch of an ambiguous union to take. - - See `fastavro notation `_ - - Args: - schema_registry_client (SchemaRegistryClient): Schema Registry client instance. - - schema_str (str or Schema): - Avro `Schema Declaration. 
`_ - Accepts either a string or a :py:class:`Schema` instance. Note that string - definitions cannot reference other schemas. For referencing other schemas, - use a :py:class:`Schema` instance. - - to_dict (callable, optional): Callable(object, SerializationContext) -> dict. Converts object to a dict. - - conf (dict): AvroSerializer configuration. - """ # noqa: E501 - __slots__ = ['_known_subjects', '_parsed_schema', '_schema', - '_schema_id', '_schema_name', '_to_dict', '_parsed_schemas'] - - _default_conf = {'auto.register.schemas': True, - 'normalize.schemas': False, - 'use.schema.id': None, - 'use.latest.version': False, - 'use.latest.with.metadata': None, - 'subject.name.strategy': topic_subject_name_strategy} - - def __init__( - self, - schema_registry_client: SchemaRegistryClient, - schema_str: Union[str, Schema, None] = None, - to_dict: Optional[Callable[[object, SerializationContext], dict]] = None, - conf: Optional[dict] = None, - rule_conf: Optional[dict] = None, - rule_registry: Optional[RuleRegistry] = None - ): - super().__init__() - if isinstance(schema_str, str): - schema = _schema_loads(schema_str) - elif isinstance(schema_str, Schema): - schema = schema_str - else: - schema = None - - self._registry = schema_registry_client - self._schema_id = None - self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._known_subjects = set() - self._parsed_schemas = ParsedSchemaCache() - - if to_dict is not None and not callable(to_dict): - raise ValueError("to_dict must be callable with the signature " - "to_dict(object, SerializationContext)->dict") - - self._to_dict = to_dict - - conf_copy = self._default_conf.copy() - if conf is not None: - conf_copy.update(conf) - - self._auto_register = conf_copy.pop('auto.register.schemas') - if not isinstance(self._auto_register, bool): - raise ValueError("auto.register.schemas must be a boolean value") - - self._normalize_schemas = conf_copy.pop('normalize.schemas') - if not isinstance(self._normalize_schemas, bool): - raise ValueError("normalize.schemas must be a boolean value") - - self._use_schema_id = conf_copy.pop('use.schema.id') - if (self._use_schema_id is not None and - not isinstance(self._use_schema_id, int)): - raise ValueError("use.schema.id must be an int value") - - self._use_latest_version = conf_copy.pop('use.latest.version') - if not isinstance(self._use_latest_version, bool): - raise ValueError("use.latest.version must be a boolean value") - if self._use_latest_version and self._auto_register: - raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') - if (self._use_latest_with_metadata is not None and - not isinstance(self._use_latest_with_metadata, dict)): - raise ValueError("use.latest.with.metadata must be a dict value") - - self._subject_name_func = conf_copy.pop('subject.name.strategy') - if not callable(self._subject_name_func): - raise ValueError("subject.name.strategy must be callable") - - if len(conf_copy) > 0: - raise ValueError("Unrecognized properties: {}" - .format(", ".join(conf_copy.keys()))) - - if schema: - parsed_schema = self._get_parsed_schema(schema) - - if isinstance(parsed_schema, list): - # if parsed_schema is a list, we have an Avro union and there - # is no valid schema name. 
This is fine because the only use of - # schema_name is for supplying the subject name to the registry - # and union types should use topic_subject_name_strategy, which - # just discards the schema name anyway - schema_name = None - else: - # The Avro spec states primitives have a name equal to their type - # i.e. {"type": "string"} has a name of string. - # This function does not comply. - # https://github.com/fastavro/fastavro/issues/415 - schema_dict = loads(schema.schema_str) - schema_name = parsed_schema.get("name", schema_dict.get("type")) - else: - schema_name = None - parsed_schema = None - - self._schema = schema - self._schema_name = schema_name - self._parsed_schema = parsed_schema - - for rule in self._rule_registry.get_executors(): - rule.configure(self._registry.config() if self._registry else {}, - rule_conf if rule_conf else {}) - - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: - """ - Serializes an object to Avro binary format, prepending it with Confluent - Schema Registry framing. - - Args: - obj (object): The object instance to serialize. - - ctx (SerializationContext): Metadata pertaining to the serialization operation. - - Raises: - SerializerError: If any error occurs serializing obj. - SchemaRegistryError: If there was an error registering the schema with - Schema Registry, or auto.register.schemas is - false and the schema was not registered. - - Returns: - bytes: Confluent Schema Registry encoded Avro bytes - """ - - if obj is None: - return None - - subject = self._subject_name_func(ctx, self._schema_name) - latest_schema = self._get_reader_schema(subject) - if latest_schema is not None: - self._schema_id = latest_schema.schema_id - elif subject not in self._known_subjects: - # Check to ensure this schema has been registered under subject_name. - if self._auto_register: - # The schema name will always be the same. We can't however register - # a schema without a subject so we set the schema_id here to handle - # the initial registration. 
- self._schema_id = self._registry.register_schema( - subject, self._schema, self._normalize_schemas) - else: - registered_schema = self._registry.lookup_schema( - subject, self._schema, self._normalize_schemas) - self._schema_id = registered_schema.schema_id - - self._known_subjects.add(subject) - - if self._to_dict is not None: - value = self._to_dict(obj, ctx) - else: - value = obj - - if latest_schema is not None: - parsed_schema = self._get_parsed_schema(latest_schema.schema) - field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 - transform(rule_ctx, parsed_schema, msg, field_transform)) - value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, - latest_schema.schema, value, get_inline_tags(parsed_schema), - field_transformer) - else: - parsed_schema = self._parsed_schema - - with _ContextStringIO() as fo: - # Write the magic byte and schema ID in network byte order (big endian) - fo.write(pack('>bI', _MAGIC_BYTE, self._schema_id)) - # write the record to the rest of the buffer - schemaless_writer(fo, parsed_schema, value) - - return fo.getvalue() - - def _get_parsed_schema(self, schema: Schema) -> AvroSchema: - parsed_schema = self._parsed_schemas.get_parsed_schema(schema) - if parsed_schema is not None: - return parsed_schema - - named_schemas = _resolve_named_schema(schema, self._registry) - prepared_schema = _schema_loads(schema.schema_str) - parsed_schema = parse_schema_with_repo( - prepared_schema.schema_str, named_schemas=named_schemas) - - self._parsed_schemas.set(schema, parsed_schema) - return parsed_schema - - -class AvroDeserializer(BaseDeserializer): - """ - Deserializer for Avro binary encoded data with Confluent Schema Registry - framing. - - +-----------------------------+----------+--------------------------------------------------+ - | Property Name | Type | Description | - +-----------------------------+----------+--------------------------------------------------+ - | | | Whether to use the latest subject version for | - | ``use.latest.version`` | bool | deserialization. | - | | | | - | | | Defaults to False. | - +-----------------------------+----------+--------------------------------------------------+ - | | | Whether to use the latest subject version with | - | ``use.latest.with.metadata``| dict | the given metadata. | - | | | | - | | | Defaults to None. | - +-----------------------------+----------+--------------------------------------------------+ - | | | Callable(SerializationContext, str) -> str | - | | | | - | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | - | | | constructed. Standard naming strategies are | - | | | defined in the confluent_kafka.schema_registry | - | | | namespace. | - | | | | - | | | Defaults to topic_subject_name_strategy. | - +-----------------------------+----------+--------------------------------------------------+ - - Note: - By default, Avro complex types are returned as dicts. This behavior can - be overridden by registering a callable ``from_dict`` with the deserializer to - convert the dicts to the desired type. - - See ``avro_consumer.py`` in the examples directory in the examples - directory for example usage. - - Args: - schema_registry_client (SchemaRegistryClient): Confluent Schema Registry - client instance. - - schema_str (str, Schema, optional): Avro reader schema declaration Accepts - either a string or a :py:class:`Schema` instance. If not provided, the - writer schema will be used as the reader schema. 
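# A minimal producer-side sketch of the AvroSerializer described above; the
# registry URL, topic and record schema are hypothetical.
from confluent_kafka.schema_registry import SchemaRegistryClient
from confluent_kafka.schema_registry.avro import AvroSerializer
from confluent_kafka.serialization import SerializationContext, MessageField

user_schema = '{"type": "record", "name": "User", "fields": [{"name": "name", "type": "string"}]}'
client = SchemaRegistryClient({'url': 'http://localhost:8081'})
serializer = AvroSerializer(client, user_schema)
payload = serializer({'name': 'Ada'}, SerializationContext('users', MessageField.VALUE))
# payload = magic byte + 4-byte schema ID + Avro-encoded record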
Note that string - definitions cannot reference other schemas. For referencing other schemas, - use a :py:class:`Schema` instance. - - from_dict (callable, optional): Callable(dict, SerializationContext) -> object. - Converts a dict to an instance of some object. - - return_record_name (bool): If True, when reading a union of records, the result will - be a tuple where the first value is the name of the record and the second value is - the record itself. Defaults to False. - - See Also: - `Apache Avro Schema Declaration `_ - - `Apache Avro Schema Resolution `_ - """ - - __slots__ = ['_reader_schema', '_from_dict', '_return_record_name', - '_schema', '_parsed_schemas'] - - _default_conf = {'use.latest.version': False, - 'use.latest.with.metadata': None, - 'subject.name.strategy': topic_subject_name_strategy} - - def __init__( - self, - schema_registry_client: SchemaRegistryClient, - schema_str: Union[str, Schema, None] = None, - from_dict: Optional[Callable[[dict, SerializationContext], object]] = None, - return_record_name: bool = False, - conf: Optional[dict] = None, - rule_conf: Optional[dict] = None, - rule_registry: Optional[RuleRegistry] = None - ): - super().__init__() - schema = None - if schema_str is not None: - if isinstance(schema_str, str): - schema = _schema_loads(schema_str) - elif isinstance(schema_str, Schema): - schema = schema_str - else: - raise TypeError('You must pass either schema string or schema object') - - self._schema = schema - self._registry = schema_registry_client - self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._parsed_schemas = ParsedSchemaCache() - self._use_schema_id = None - - conf_copy = self._default_conf.copy() - if conf is not None: - conf_copy.update(conf) - - self._use_latest_version = conf_copy.pop('use.latest.version') - if not isinstance(self._use_latest_version, bool): - raise ValueError("use.latest.version must be a boolean value") - - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') - if (self._use_latest_with_metadata is not None and - not isinstance(self._use_latest_with_metadata, dict)): - raise ValueError("use.latest.with.metadata must be a dict value") - - self._subject_name_func = conf_copy.pop('subject.name.strategy') - if not callable(self._subject_name_func): - raise ValueError("subject.name.strategy must be callable") - - if len(conf_copy) > 0: - raise ValueError("Unrecognized properties: {}" - .format(", ".join(conf_copy.keys()))) - - if schema: - self._reader_schema = self._get_parsed_schema(self._schema) - else: - self._reader_schema = None - - if from_dict is not None and not callable(from_dict): - raise ValueError("from_dict must be callable with the signature " - "from_dict(SerializationContext, dict) -> object") - self._from_dict = from_dict - - self._return_record_name = return_record_name - if not isinstance(self._return_record_name, bool): - raise ValueError("return_record_name must be a boolean value") - - for rule in self._rule_registry.get_executors(): - rule.configure(self._registry.config() if self._registry else {}, - rule_conf if rule_conf else {}) - - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: - """ - Deserialize Avro binary encoded data with Confluent Schema Registry framing to - a dict, or object instance according to from_dict, if specified. - - Arguments: - data (bytes): bytes - - ctx (SerializationContext): Metadata relevant to the serialization - operation. 
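# A minimal sketch of the 5-byte Confluent framing referred to above: one
# magic byte (0) followed by the schema ID as a 4-byte big-endian integer,
# with the Avro-encoded payload following. The schema ID value is hypothetical.
from struct import pack, unpack

schema_id = 42
header = pack('>bI', 0, schema_id)        # _MAGIC_BYTE == 0
assert len(header) == 5
magic, parsed_id = unpack('>bI', header)
assert magic == 0 and parsed_id == 42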
- - Raises: - SerializerError: if an error occurs parsing data. - - Returns: - object: If data is None, then None. Else, a dict, or object instance according - to from_dict, if specified. - """ # noqa: E501 - - if data is None: - return None - - if len(data) <= 5: - raise SerializationError("Expecting data framing of length 6 bytes or " - "more but total data size is {} bytes. This " - "message was not produced with a Confluent " - "Schema Registry serializer".format(len(data))) - - subject = self._subject_name_func(ctx, None) - latest_schema = None - if subject is not None: - latest_schema = self._get_reader_schema(subject) - - with _ContextStringIO(data) as payload: - magic, schema_id = unpack('>bI', payload.read(5)) - if magic != _MAGIC_BYTE: - raise SerializationError("Unexpected magic byte {}. This message " - "was not produced with a Confluent " - "Schema Registry serializer".format(magic)) - - writer_schema_raw = self._registry.get_schema(schema_id) - writer_schema = self._get_parsed_schema(writer_schema_raw) - - if subject is None: - subject = self._subject_name_func(ctx, writer_schema.get("name")) - if subject is not None: - latest_schema = self._get_reader_schema(subject) - - if latest_schema is not None: - migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) - reader_schema_raw = latest_schema.schema - reader_schema = self._get_parsed_schema(latest_schema.schema) - elif self._schema is not None: - migrations = None - reader_schema_raw = self._schema - reader_schema = self._reader_schema - else: - migrations = None - reader_schema_raw = writer_schema_raw - reader_schema = writer_schema - - if migrations: - obj_dict = schemaless_reader(payload, - writer_schema, - None, - self._return_record_name) - obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) - else: - obj_dict = schemaless_reader(payload, - writer_schema, - reader_schema, - self._return_record_name) - - field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 - transform(rule_ctx, reader_schema, message, field_transform)) - obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, obj_dict, get_inline_tags(reader_schema), - field_transformer) - - if self._from_dict is not None: - return self._from_dict(obj_dict, ctx) - - return obj_dict - - def _get_parsed_schema(self, schema: Schema) -> AvroSchema: - parsed_schema = self._parsed_schemas.get_parsed_schema(schema) - if parsed_schema is not None: - return parsed_schema - - named_schemas = _resolve_named_schema(schema, self._registry) - prepared_schema = _schema_loads(schema.schema_str) - parsed_schema = parse_schema_with_repo( - prepared_schema.schema_str, named_schemas=named_schemas) - - self._parsed_schemas.set(schema, parsed_schema) - return parsed_schema - - -class LocalSchemaRepository(repository.AbstractSchemaRepository): - def __init__(self, schemas): - self.schemas = schemas - - def load(self, subject): - return self.schemas.get(subject) - - -def parse_schema_with_repo(schema_str: str, named_schemas: Dict[str, AvroSchema]) -> AvroSchema: - copy = deepcopy(named_schemas) - copy["$root"] = loads(schema_str) - repo = LocalSchemaRepository(copy) - return load_schema("$root", repo=repo) - - -def transform( - ctx: RuleContext, schema: AvroSchema, message: AvroMessage, - field_transform: FieldTransform -) -> AvroMessage: - if message is None or schema is None: - return message - field_ctx = ctx.current_field() - if field_ctx is not None: - field_ctx.field_type = 
get_type(schema) - if isinstance(schema, list): - subschema = _resolve_union(schema, message) - if subschema is None: - return message - return transform(ctx, subschema, message, field_transform) - elif isinstance(schema, dict): - schema_type = schema.get("type") - if schema_type == 'array': - return [transform(ctx, schema["items"], item, field_transform) - for item in message] - elif schema_type == 'map': - return {key: transform(ctx, schema["values"], value, field_transform) - for key, value in message.items()} - elif schema_type == 'record': - fields = schema["fields"] - for field in fields: - _transform_field(ctx, schema, field, message, field_transform) - return message - - if field_ctx is not None: - rule_tags = ctx.rule.tags - if not rule_tags or not _disjoint(set(rule_tags), field_ctx.tags): - return field_transform(ctx, field_ctx, message) - return message - - -def _transform_field( - ctx: RuleContext, schema: AvroSchema, field: dict, - message: AvroMessage, field_transform: FieldTransform -): - field_type = field["type"] - name = field["name"] - full_name = schema["name"] + "." + name - try: - ctx.enter_field( - message, - full_name, - name, - get_type(field_type), - None - ) - value = message[name] - new_value = transform(ctx, field_type, value, field_transform) - if ctx.rule.kind == RuleKind.CONDITION: - if new_value is False: - raise RuleConditionError(ctx.rule) - else: - message[name] = new_value - finally: - ctx.exit_field() - - -def get_type(schema: AvroSchema) -> FieldType: - if isinstance(schema, list): - return FieldType.COMBINED - elif isinstance(schema, dict): - schema_type = schema.get("type") - else: - # string schemas; this could be either a named schema or a primitive type - schema_type = schema - - if schema_type == 'record': - return FieldType.RECORD - elif schema_type == 'enum': - return FieldType.ENUM - elif schema_type == 'array': - return FieldType.ARRAY - elif schema_type == 'map': - return FieldType.MAP - elif schema_type == 'union': - return FieldType.COMBINED - elif schema_type == 'fixed': - return FieldType.FIXED - elif schema_type == 'string': - return FieldType.STRING - elif schema_type == 'bytes': - return FieldType.BYTES - elif schema_type == 'int': - return FieldType.INT - elif schema_type == 'long': - return FieldType.LONG - elif schema_type == 'float': - return FieldType.FLOAT - elif schema_type == 'double': - return FieldType.DOUBLE - elif schema_type == 'boolean': - return FieldType.BOOLEAN - elif schema_type == 'null': - return FieldType.NULL - else: - return FieldType.NULL - - -def _disjoint(tags1: Set[str], tags2: Set[str]) -> bool: - for tag in tags1: - if tag in tags2: - return False - return True - - -def _resolve_union(schema: AvroSchema, message: AvroMessage) -> Optional[AvroSchema]: - for subschema in schema: - try: - validate(message, subschema) - except: # noqa: E722 - continue - return subschema - return None - - -def get_inline_tags(schema: AvroSchema) -> Dict[str, Set[str]]: - inline_tags = defaultdict(set) - _get_inline_tags_recursively('', '', schema, inline_tags) - return inline_tags - - -def _get_inline_tags_recursively( - ns: str, name: str, schema: Optional[AvroSchema], - tags: Dict[str, Set[str]] -): - if schema is None: - return - if isinstance(schema, list): - for subschema in schema: - _get_inline_tags_recursively(ns, name, subschema, tags) - elif not isinstance(schema, dict): - # string schemas; this could be either a named schema or a primitive type - return - else: - schema_type = schema.get("type") - if schema_type == 
'array': - _get_inline_tags_recursively(ns, name, schema.get("items"), tags) - elif schema_type == 'map': - _get_inline_tags_recursively(ns, name, schema.get("values"), tags) - elif schema_type == 'record': - record_ns = schema.get("namespace") - record_name = schema.get("name") - if record_ns is None: - record_ns = _implied_namespace(name) - if record_ns is None: - record_ns = ns - if record_ns != '' and not record_name.startswith(record_ns): - record_name = f"{record_ns}.{record_name}" - fields = schema["fields"] - for field in fields: - field_tags = field.get("confluent:tags") - field_name = field.get("name") - field_type = field.get("type") - if field_tags is not None and field_name is not None: - tags[record_name + '.' + field_name].update(field_tags) - if field_type is not None: - _get_inline_tags_recursively(record_ns, record_name, field_type, tags) - - -def _implied_namespace(name: str) -> Optional[str]: - match = re.match(r"^(.*)\.[^.]+$", name) - return match.group(1) if match else None +from .common.avro import * +from ..avro import * +from ._sync.avro import * diff --git a/src/confluent_kafka/schema_registry/common/__init__.py b/src/confluent_kafka/schema_registry/common/__init__.py new file mode 100644 index 000000000..c1e14957a --- /dev/null +++ b/src/confluent_kafka/schema_registry/common/__init__.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from io import BytesIO +from typing import Optional + +from .schema_registry_client import SchemaReference + +_MAGIC_BYTE = 0 + +def topic_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: + """ + Constructs a subject name in the form of {topic}-key|value. + + Args: + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + record_name (Optional[str]): Record name. + + """ + return ctx.topic + "-" + ctx.field + + +def topic_record_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: + """ + Constructs a subject name in the form of {topic}-{record_name}. + + Args: + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + record_name (Optional[str]): Record name. + + """ + return ctx.topic + "-" + record_name if record_name is not None else None + + +def record_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: + """ + Constructs a subject name in the form of {record_name}. + + Args: + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + record_name (Optional[str]): Record name. + + """ + return record_name if record_name is not None else None + + +def reference_subject_name_strategy(ctx, schema_ref: SchemaReference) -> Optional[str]: + """ + Constructs a subject reference name in the form of {reference name}. + + Args: + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + schema_ref (SchemaReference): SchemaReference instance. 
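# A minimal sketch of the subject name strategies defined above, using a
# hypothetical topic 'orders' and record name 'Order'.
from confluent_kafka.schema_registry import (
    topic_subject_name_strategy,
    topic_record_subject_name_strategy,
    record_subject_name_strategy,
)
from confluent_kafka.serialization import SerializationContext, MessageField

ctx = SerializationContext('orders', MessageField.VALUE)
print(topic_subject_name_strategy(ctx, 'Order'))          # 'orders-value'
print(topic_record_subject_name_strategy(ctx, 'Order'))   # 'orders-Order'
print(record_subject_name_strategy(ctx, 'Order'))         # 'Order'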
+ + """ + return schema_ref.name if schema_ref is not None else None + + +class _ContextStringIO(BytesIO): + """ + Wrapper to allow use of StringIO via 'with' constructs. + """ + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + return False diff --git a/src/confluent_kafka/schema_registry/common/avro.py b/src/confluent_kafka/schema_registry/common/avro.py new file mode 100644 index 000000000..a36e096c7 --- /dev/null +++ b/src/confluent_kafka/schema_registry/common/avro.py @@ -0,0 +1,233 @@ +import decimal +import re +from collections import defaultdict +from copy import deepcopy +from json import loads +from typing import Dict, Union, Optional, Set + +from fastavro import repository, validate +from fastavro.schema import load_schema + +from .schema_registry_client import Schema, RuleKind +from confluent_kafka.schema_registry.serde import RuleContext, FieldType, \ + FieldTransform, RuleConditionError + + +AvroMessage = Union[ + None, # 'null' Avro type + str, # 'string' and 'enum' + float, # 'float' and 'double' + int, # 'int' and 'long' + decimal.Decimal, # 'fixed' + bool, # 'boolean' + bytes, # 'bytes' + list, # 'array' + dict, # 'map' and 'record' +] +AvroSchema = Union[str, list, dict] + + +def _schema_loads(schema_str: str) -> Schema: + """ + Instantiate a Schema instance from a declaration string. + + Args: + schema_str (str): Avro Schema declaration. + + .. _Schema declaration: + https://avro.apache.org/docs/current/spec.html#schemas + + Returns: + Schema: A Schema instance. + """ + + schema_str = schema_str.strip() + + # canonical form primitive declarations are not supported + if schema_str[0] != "{" and schema_str[0] != "[": + schema_str = '{"type":' + schema_str + '}' + + return Schema(schema_str, schema_type='AVRO') + + +class LocalSchemaRepository(repository.AbstractSchemaRepository): + def __init__(self, schemas): + self.schemas = schemas + + def load(self, subject): + return self.schemas.get(subject) + + +def parse_schema_with_repo(schema_str: str, named_schemas: Dict[str, AvroSchema]) -> AvroSchema: + copy = deepcopy(named_schemas) + copy["$root"] = loads(schema_str) + repo = LocalSchemaRepository(copy) + return load_schema("$root", repo=repo) + + +def transform( + ctx: RuleContext, schema: AvroSchema, message: AvroMessage, + field_transform: FieldTransform +) -> AvroMessage: + if message is None or schema is None: + return message + field_ctx = ctx.current_field() + if field_ctx is not None: + field_ctx.field_type = get_type(schema) + if isinstance(schema, list): + subschema = _resolve_union(schema, message) + if subschema is None: + return message + return transform(ctx, subschema, message, field_transform) + elif isinstance(schema, dict): + schema_type = schema.get("type") + if schema_type == 'array': + return [transform(ctx, schema["items"], item, field_transform) + for item in message] + elif schema_type == 'map': + return {key: transform(ctx, schema["values"], value, field_transform) + for key, value in message.items()} + elif schema_type == 'record': + fields = schema["fields"] + for field in fields: + _transform_field(ctx, schema, field, message, field_transform) + return message + + if field_ctx is not None: + rule_tags = ctx.rule.tags + if not rule_tags or not _disjoint(set(rule_tags), field_ctx.tags): + return field_transform(ctx, field_ctx, message) + return message + + +def _transform_field( + ctx: RuleContext, schema: AvroSchema, field: dict, + message: AvroMessage, field_transform: FieldTransform +): + field_type = 
field["type"] + name = field["name"] + full_name = schema["name"] + "." + name + try: + ctx.enter_field( + message, + full_name, + name, + get_type(field_type), + None + ) + value = message[name] + new_value = transform(ctx, field_type, value, field_transform) + if ctx.rule.kind == RuleKind.CONDITION: + if new_value is False: + raise RuleConditionError(ctx.rule) + else: + message[name] = new_value + finally: + ctx.exit_field() + + +def get_type(schema: AvroSchema) -> FieldType: + if isinstance(schema, list): + return FieldType.COMBINED + elif isinstance(schema, dict): + schema_type = schema.get("type") + else: + # string schemas; this could be either a named schema or a primitive type + schema_type = schema + + if schema_type == 'record': + return FieldType.RECORD + elif schema_type == 'enum': + return FieldType.ENUM + elif schema_type == 'array': + return FieldType.ARRAY + elif schema_type == 'map': + return FieldType.MAP + elif schema_type == 'union': + return FieldType.COMBINED + elif schema_type == 'fixed': + return FieldType.FIXED + elif schema_type == 'string': + return FieldType.STRING + elif schema_type == 'bytes': + return FieldType.BYTES + elif schema_type == 'int': + return FieldType.INT + elif schema_type == 'long': + return FieldType.LONG + elif schema_type == 'float': + return FieldType.FLOAT + elif schema_type == 'double': + return FieldType.DOUBLE + elif schema_type == 'boolean': + return FieldType.BOOLEAN + elif schema_type == 'null': + return FieldType.NULL + else: + return FieldType.NULL + + +def _disjoint(tags1: Set[str], tags2: Set[str]) -> bool: + for tag in tags1: + if tag in tags2: + return False + return True + + +def _resolve_union(schema: AvroSchema, message: AvroMessage) -> Optional[AvroSchema]: + for subschema in schema: + try: + validate(message, subschema) + except: # noqa: E722 + continue + return subschema + return None + + +def get_inline_tags(schema: AvroSchema) -> Dict[str, Set[str]]: + inline_tags = defaultdict(set) + _get_inline_tags_recursively('', '', schema, inline_tags) + return inline_tags + + +def _get_inline_tags_recursively( + ns: str, name: str, schema: Optional[AvroSchema], + tags: Dict[str, Set[str]] +): + if schema is None: + return + if isinstance(schema, list): + for subschema in schema: + _get_inline_tags_recursively(ns, name, subschema, tags) + elif not isinstance(schema, dict): + # string schemas; this could be either a named schema or a primitive type + return + else: + schema_type = schema.get("type") + if schema_type == 'array': + _get_inline_tags_recursively(ns, name, schema.get("items"), tags) + elif schema_type == 'map': + _get_inline_tags_recursively(ns, name, schema.get("values"), tags) + elif schema_type == 'record': + record_ns = schema.get("namespace") + record_name = schema.get("name") + if record_ns is None: + record_ns = _implied_namespace(name) + if record_ns is None: + record_ns = ns + if record_ns != '' and not record_name.startswith(record_ns): + record_name = f"{record_ns}.{record_name}" + fields = schema["fields"] + for field in fields: + field_tags = field.get("confluent:tags") + field_name = field.get("name") + field_type = field.get("type") + if field_tags is not None and field_name is not None: + tags[record_name + '.' 
+ field_name].update(field_tags) + if field_type is not None: + _get_inline_tags_recursively(record_ns, record_name, field_type, tags) + + +def _implied_namespace(name: str) -> Optional[str]: + match = re.match(r"^(.*)\.[^.]+$", name) + return match.group(1) if match else None diff --git a/src/confluent_kafka/schema_registry/common/json_schema.py b/src/confluent_kafka/schema_registry/common/json_schema.py new file mode 100644 index 000000000..2a147c09b --- /dev/null +++ b/src/confluent_kafka/schema_registry/common/json_schema.py @@ -0,0 +1,167 @@ + +import decimal +from io import BytesIO + +from typing import Union, Optional, List, Set + +import httpx +import referencing +from jsonschema import validate, ValidationError +from referencing import Registry, Resource +from referencing._core import Resolver + +from confluent_kafka.schema_registry import RuleKind +from confluent_kafka.schema_registry.serde import RuleContext, FieldTransform, FieldType, \ + RuleConditionError + +JsonMessage = Union[ + None, # 'null' Avro type + str, # 'string' and 'enum' + float, # 'float' and 'double' + int, # 'int' and 'long' + decimal.Decimal, # 'fixed' + bool, # 'boolean' + list, # 'array' + dict, # 'map' and 'record' +] + +JsonSchema = Union[bool, dict] + +DEFAULT_SPEC = referencing.jsonschema.DRAFT7 + +def _retrieve_via_httpx(uri: str): + response = httpx.get(uri) + return Resource.from_contents( + response.json(), default_specification=DEFAULT_SPEC) + + +def transform( + ctx: RuleContext, schema: JsonSchema, ref_registry: Registry, ref_resolver: Resolver, + path: str, message: JsonMessage, field_transform: FieldTransform +) -> Optional[JsonMessage]: + if message is None or schema is None or isinstance(schema, bool): + return message + field_ctx = ctx.current_field() + if field_ctx is not None: + field_ctx.field_type = get_type(schema) + all_of = schema.get("allOf") + if all_of is not None: + subschema = _validate_subschemas(all_of, message, ref_registry) + if subschema is not None: + return transform(ctx, subschema, ref_registry, ref_resolver, path, message, field_transform) + any_of = schema.get("anyOf") + if any_of is not None: + subschema = _validate_subschemas(any_of, message, ref_registry) + if subschema is not None: + return transform(ctx, subschema, ref_registry, ref_resolver, path, message, field_transform) + one_of = schema.get("oneOf") + if one_of is not None: + subschema = _validate_subschemas(one_of, message, ref_registry) + if subschema is not None: + return transform(ctx, subschema, ref_registry, ref_resolver, path, message, field_transform) + items = schema.get("items") + if items is not None: + if isinstance(message, list): + return [transform(ctx, items, ref_registry, ref_resolver, path, item, field_transform) for item in message] + ref = schema.get("$ref") + if ref is not None: + ref_schema = ref_resolver.lookup(ref) + return transform(ctx, ref_schema.contents, ref_registry, ref_resolver, path, message, field_transform) + schema_type = get_type(schema) + if schema_type == FieldType.RECORD: + props = schema.get("properties") + if props is not None: + for prop_name, prop_schema in props.items(): + _transform_field(ctx, path, prop_name, message, + prop_schema, ref_registry, ref_resolver, field_transform) + return message + if schema_type in (FieldType.ENUM, FieldType.STRING, FieldType.INT, FieldType.DOUBLE, FieldType.BOOLEAN): + if field_ctx is not None: + rule_tags = ctx.rule.tags + if not rule_tags or not _disjoint(set(rule_tags), field_ctx.tags): + return field_transform(ctx, field_ctx, 
message) + return message + + +def _transform_field( + ctx: RuleContext, path: str, prop_name: str, message: JsonMessage, + prop_schema: JsonSchema, ref_registry: Registry, ref_resolver: Resolver, field_transform: FieldTransform +): + full_name = path + "." + prop_name + try: + ctx.enter_field( + message, + full_name, + prop_name, + get_type(prop_schema), + get_inline_tags(prop_schema) + ) + value = message[prop_name] + new_value = transform(ctx, prop_schema, ref_registry, ref_resolver, full_name, value, field_transform) + if ctx.rule.kind == RuleKind.CONDITION: + if new_value is False: + raise RuleConditionError(ctx.rule) + else: + message[prop_name] = new_value + finally: + ctx.exit_field() + + +def _validate_subschemas( + subschemas: List[JsonSchema], + message: JsonMessage, + registry: Registry +) -> Optional[JsonSchema]: + for subschema in subschemas: + try: + validate(instance=message, schema=subschema, registry=registry) + return subschema + except ValidationError: + pass + return None + + +def get_type(schema: JsonSchema) -> FieldType: + if isinstance(schema, list): + return FieldType.COMBINED + elif isinstance(schema, dict): + schema_type = schema.get("type") + else: + # string schemas; this could be either a named schema or a primitive type + schema_type = schema + + if schema.get("const") is not None or schema.get("enum") is not None: + return FieldType.ENUM + if schema_type == "object": + props = schema.get("properties") + if not props: + return FieldType.MAP + return FieldType.RECORD + if schema_type == "array": + return FieldType.ARRAY + if schema_type == "string": + return FieldType.STRING + if schema_type == "integer": + return FieldType.INT + if schema_type == "number": + return FieldType.DOUBLE + if schema_type == "boolean": + return FieldType.BOOLEAN + if schema_type == "null": + return FieldType.NULL + return FieldType.NULL + + +def _disjoint(tags1: Set[str], tags2: Set[str]) -> bool: + for tag in tags1: + if tag in tags2: + return False + return True + + +def get_inline_tags(schema: JsonSchema) -> Set[str]: + tags = schema.get("confluent:tags") + if tags is None: + return set() + else: + return set(tags) diff --git a/src/confluent_kafka/schema_registry/common/protobuf.py b/src/confluent_kafka/schema_registry/common/protobuf.py new file mode 100644 index 000000000..bb8c9b6d4 --- /dev/null +++ b/src/confluent_kafka/schema_registry/common/protobuf.py @@ -0,0 +1,358 @@ +import io +import sys +import base64 +from collections import deque +from decimal import Context, Decimal, MAX_PREC +from typing import Set, List, Any + +from google.protobuf import descriptor_pb2, any_pb2, api_pb2, empty_pb2, \ + duration_pb2, field_mask_pb2, source_context_pb2, struct_pb2, timestamp_pb2, \ + type_pb2, wrappers_pb2 +from google.protobuf.descriptor_pool import DescriptorPool +from google.type import calendar_period_pb2, color_pb2, date_pb2, datetime_pb2, \ + dayofweek_pb2, expr_pb2, fraction_pb2, latlng_pb2, money_pb2, month_pb2, \ + postal_address_pb2, timeofday_pb2, quaternion_pb2 + +import confluent_kafka.schema_registry.confluent.meta_pb2 as meta_pb2 + +from google.protobuf.descriptor import Descriptor, FieldDescriptor, \ + FileDescriptor +from google.protobuf.message import DecodeError, Message + +from confluent_kafka.schema_registry.confluent.types import decimal_pb2 +from confluent_kafka.schema_registry import RuleKind +from confluent_kafka.serialization import SerializationError +from confluent_kafka.schema_registry.serde import RuleContext, FieldTransform, \ + FieldType, 
RuleConditionError + +# Convert an int to bytes (inverse of ord()) +# Python3.chr() -> Unicode +# Python2.chr() -> str(alias for bytes) +if sys.version > '3': + def _bytes(v: int) -> bytes: + """ + Convert int to bytes + + Args: + v (int): The int to convert to bytes. + """ + return bytes((v,)) +else: + def _bytes(v: int) -> str: + """ + Convert int to bytes + + Args: + v (int): The int to convert to bytes. + """ + return chr(v) + + +def _create_index_array(msg_desc: Descriptor) -> List[int]: + """ + Creates an index array specifying the location of msg_desc in + the referenced FileDescriptor. + + Args: + msg_desc (MessageDescriptor): Protobuf MessageDescriptor + + Returns: + list of int: Protobuf MessageDescriptor index array. + + Raises: + ValueError: If the message descriptor is malformed. + """ + + msg_idx = deque() + + # Walk the nested MessageDescriptor tree up to the root. + current = msg_desc + found = False + while current.containing_type is not None: + previous = current + current = previous.containing_type + # find child's position + for idx, node in enumerate(current.nested_types): + if node == previous: + msg_idx.appendleft(idx) + found = True + break + if not found: + raise ValueError("Nested MessageDescriptor not found") + + # Add the index of the root MessageDescriptor in the FileDescriptor. + found = False + for idx, msg_type_name in enumerate(msg_desc.file.message_types_by_name): + if msg_type_name == current.name: + msg_idx.appendleft(idx) + found = True + break + if not found: + raise ValueError("MessageDescriptor not found in file") + + return list(msg_idx) + + +def _schema_to_str(file_descriptor: FileDescriptor) -> str: + """ + Base64 encode a FileDescriptor + + Args: + file_descriptor (FileDescriptor): FileDescriptor to encode. + + Returns: + str: Base64 encoded FileDescriptor + """ + + return base64.standard_b64encode(file_descriptor.serialized_pb).decode('ascii') + + +def _proto_to_str(file_descriptor_proto: descriptor_pb2.FileDescriptorProto) -> str: + """ + Base64 encode a FileDescriptorProto + + Args: + file_descriptor_proto (FileDescriptorProto): FileDescriptorProto to encode. + + Returns: + str: Base64 encoded FileDescriptorProto + """ + + return base64.standard_b64encode(file_descriptor_proto.SerializeToString()).decode('ascii') + + +def _str_to_proto(name: str, schema_str: str) -> descriptor_pb2.FileDescriptorProto: + """ + Base64 decode a FileDescriptor + + Args: + schema_str (str): Base64 encoded FileDescriptorProto + + Returns: + FileDescriptorProto: schema. 
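# A minimal sketch of the base64 round-trip performed by _proto_to_str and
# _str_to_proto above, using a small hand-built FileDescriptorProto; the file
# and message names are hypothetical.
import base64
from google.protobuf import descriptor_pb2

fdp = descriptor_pb2.FileDescriptorProto()
fdp.name = 'example.proto'
fdp.message_type.add(name='Order')

encoded = base64.standard_b64encode(fdp.SerializeToString()).decode('ascii')

decoded = descriptor_pb2.FileDescriptorProto()
decoded.ParseFromString(base64.standard_b64decode(encoded.encode('ascii')))
assert decoded.message_type[0].name == 'Order'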
+ """ + + serialized_pb = base64.standard_b64decode(schema_str.encode('ascii')) + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + try: + file_descriptor_proto.ParseFromString(serialized_pb) + file_descriptor_proto.name = name + except DecodeError as e: + raise SerializationError(str(e)) + return file_descriptor_proto + + +def _init_pool(pool: DescriptorPool): + pool.AddSerializedFile(any_pb2.DESCRIPTOR.serialized_pb) + # source_context needed by api + pool.AddSerializedFile(source_context_pb2.DESCRIPTOR.serialized_pb) + # type needed by api + pool.AddSerializedFile(type_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(api_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(descriptor_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(duration_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(empty_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(field_mask_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(struct_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(timestamp_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(wrappers_pb2.DESCRIPTOR.serialized_pb) + + pool.AddSerializedFile(calendar_period_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(color_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(date_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(datetime_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(dayofweek_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(expr_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(fraction_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(latlng_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(money_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(month_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(postal_address_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(quaternion_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(timeofday_pb2.DESCRIPTOR.serialized_pb) + + pool.AddSerializedFile(meta_pb2.DESCRIPTOR.serialized_pb) + pool.AddSerializedFile(decimal_pb2.DESCRIPTOR.serialized_pb) + + +def transform( + ctx: RuleContext, descriptor: Descriptor, message: Any, + field_transform: FieldTransform +) -> Any: + if message is None or descriptor is None: + return message + if isinstance(message, list): + return [transform(ctx, descriptor, item, field_transform) + for item in message] + if isinstance(message, dict): + return {key: transform(ctx, descriptor, value, field_transform) + for key, value in message.items()} + if isinstance(message, Message): + for fd in descriptor.fields: + _transform_field(ctx, fd, descriptor, message, field_transform) + return message + field_ctx = ctx.current_field() + if field_ctx is not None: + rule_tags = ctx.rule.tags + if not rule_tags or not _disjoint(set(rule_tags), field_ctx.tags): + return field_transform(ctx, field_ctx, message) + return message + + +def _transform_field( + ctx: RuleContext, fd: FieldDescriptor, desc: Descriptor, + message: Message, field_transform: FieldTransform +): + try: + ctx.enter_field( + message, + fd.full_name, + fd.name, + get_type(fd), + get_inline_tags(fd) + ) + if fd.containing_oneof is not None and not message.HasField(fd.name): + return + value = getattr(message, fd.name) + if is_map_field(fd): + value = {key: value[key] for key in value} + elif fd.label == FieldDescriptor.LABEL_REPEATED: + value = [item for item in value] + new_value = transform(ctx, desc, value, field_transform) + if ctx.rule.kind == RuleKind.CONDITION: + if new_value is False: + raise 
RuleConditionError(ctx.rule) + else: + _set_field(fd, message, new_value) + finally: + ctx.exit_field() + + +def _set_field(fd: FieldDescriptor, message: Message, value: Any): + if isinstance(value, list): + message.ClearField(fd.name) + old_value = getattr(message, fd.name) + old_value.extend(value) + elif isinstance(value, dict): + message.ClearField(fd.name) + old_value = getattr(message, fd.name) + old_value.update(value) + else: + setattr(message, fd.name, value) + + +def get_type(fd: FieldDescriptor) -> FieldType: + if is_map_field(fd): + return FieldType.MAP + if fd.type == FieldDescriptor.TYPE_MESSAGE: + return FieldType.RECORD + if fd.type == FieldDescriptor.TYPE_ENUM: + return FieldType.ENUM + if fd.type == FieldDescriptor.TYPE_STRING: + return FieldType.STRING + if fd.type == FieldDescriptor.TYPE_BYTES: + return FieldType.BYTES + if fd.type in (FieldDescriptor.TYPE_INT32, FieldDescriptor.TYPE_SINT32, + FieldDescriptor.TYPE_UINT32, FieldDescriptor.TYPE_FIXED32, + FieldDescriptor.TYPE_SFIXED32): + return FieldType.INT + if fd.type in (FieldDescriptor.TYPE_INT64, FieldDescriptor.TYPE_SINT64, + FieldDescriptor.TYPE_UINT64, FieldDescriptor.TYPE_FIXED64, + FieldDescriptor.TYPE_SFIXED64): + return FieldType.LONG + if fd.type == FieldDescriptor.TYPE_FLOAT: + return FieldType.FLOAT + if fd.type == FieldDescriptor.TYPE_DOUBLE: + return FieldType.DOUBLE + if fd.type == FieldDescriptor.TYPE_BOOL: + return FieldType.BOOLEAN + return FieldType.NULL + + +def is_map_field(fd: FieldDescriptor): + return (fd.type == FieldDescriptor.TYPE_MESSAGE + and hasattr(fd.message_type, 'options') + and fd.message_type.options.map_entry) + + +def get_inline_tags(fd: FieldDescriptor) -> Set[str]: + meta = fd.GetOptions().Extensions[meta_pb2.field_meta] + if meta is None: + return set() + else: + return set(meta.tags) + + +def _disjoint(tags1: Set[str], tags2: Set[str]) -> bool: + for tag in tags1: + if tag in tags2: + return False + return True + + +def _is_builtin(name: str) -> bool: + return name.startswith('confluent/') or \ + name.startswith('google/protobuf/') or \ + name.startswith('google/type/') + + +def decimalToProtobuf(value: Decimal, scale: int) -> decimal_pb2.Decimal: + """ + Converts a Decimal to a Protobuf value. + + Args: + value (Decimal): The Decimal value to convert. + + Returns: + The Protobuf value. + """ + sign, digits, exp = value.as_tuple() + + delta = exp + scale + + if delta < 0: + raise ValueError( + "Scale provided does not match the decimal") + + unscaled_datum = 0 + for digit in digits: + unscaled_datum = (unscaled_datum * 10) + digit + + unscaled_datum = 10**delta * unscaled_datum + + bytes_req = (unscaled_datum.bit_length() + 8) // 8 + + if sign: + unscaled_datum = -unscaled_datum + + bytes = unscaled_datum.to_bytes(bytes_req, byteorder="big", signed=True) + + result = decimal_pb2.Decimal() + result.value = bytes + result.precision = 0 + result.scale = scale + return result + + +decimal_context = Context() + + +def protobufToDecimal(value: decimal_pb2.Decimal) -> Decimal: + """ + Converts a Protobuf value to Decimal. + + Args: + value (decimal_pb2.Decimal): The Protobuf value to convert. + + Returns: + The Decimal value. 
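A worked round trip through the two decimal converters defined in this module (illustrative; the values are made up, and decimal_pb2.Decimal is the bundled Confluent type imported at the top of the file):

from decimal import Decimal

# Decimal("123.45") with scale 2 has the unscaled integer 12345,
# stored as two big-endian, signed bytes.
pb = decimalToProtobuf(Decimal("123.45"), 2)
assert pb.value == (12345).to_bytes(2, byteorder="big", signed=True)
assert pb.scale == 2

# Converting back recovers the original value.
assert protobufToDecimal(pb) == Decimal("123.45")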
+ """ + unscaled_datum = int.from_bytes(value.value, byteorder="big", signed=True) + + if value.precision > 0: + decimal_context.prec = value.precision + else: + decimal_context.prec = MAX_PREC + return decimal_context.create_decimal(unscaled_datum).scaleb( + -value.scale, decimal_context + ) diff --git a/src/confluent_kafka/schema_registry/common/schema_registry_client.py b/src/confluent_kafka/schema_registry/common/schema_registry_client.py new file mode 100644 index 000000000..ca8efbe18 --- /dev/null +++ b/src/confluent_kafka/schema_registry/common/schema_registry_client.py @@ -0,0 +1,897 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import random + +from attrs import define as _attrs_define +from attrs import field as _attrs_field +from collections import defaultdict +from enum import Enum +from threading import Lock +from typing import List, Dict, Type, TypeVar, \ + cast, Optional, Any + +VALID_AUTH_PROVIDERS = ['URL', 'USER_INFO'] + +class _BearerFieldProvider(metaclass=abc.ABCMeta): + @abc.abstractmethod + def get_bearer_fields(self) -> dict: + raise NotImplementedError + + +def is_success(status_code: int) -> bool: + return 200 <= status_code <= 299 + + +def is_retriable(status_code: int) -> bool: + return status_code in (408, 429, 500, 502, 503, 504) + + +def full_jitter(base_delay_ms: int, max_delay_ms: int, retries_attempted: int) -> float: + no_jitter_delay = base_delay_ms * (2.0 ** retries_attempted) + return random.random() * min(no_jitter_delay, max_delay_ms) + + +class _SchemaCache(object): + """ + Thread-safe cache for use with the Schema Registry Client. + + This cache may be used to retrieve schema ids, schemas or to check + known subject membership. + """ + + def __init__(self): + self.lock = Lock() + self.schema_id_index = defaultdict(dict) + self.schema_index = defaultdict(dict) + self.rs_id_index = defaultdict(dict) + self.rs_version_index = defaultdict(dict) + self.rs_schema_index = defaultdict(dict) + + def set_schema(self, subject: str, schema_id: int, schema: 'Schema'): + """ + Add a Schema identified by schema_id to the cache. + + Args: + subject (str): The subject this schema is associated with + + schema_id (int): Schema's registration id + + schema (Schema): Schema instance + """ + + with self.lock: + self.schema_id_index[subject][schema_id] = schema + self.schema_index[subject][schema] = schema_id + + def set_registered_schema(self, schema: 'Schema', registered_schema: 'RegisteredSchema'): + """ + Add a RegisteredSchema to the cache. 
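Earlier in this module, is_retriable() and full_jitter() implement capped exponential backoff with full jitter. A minimal sketch of how a caller might combine them (illustrative; send is a hypothetical callable returning an HTTP status code and body, not part of this patch):

import time

def call_with_retries(send, max_retries=3, base_delay_ms=1000, max_delay_ms=20000):
    # Retry only retriable status codes, sleeping a jittered, exponentially
    # growing delay (full_jitter returns milliseconds) between attempts.
    for attempt in range(max_retries + 1):
        status, body = send()
        if is_success(status) or not is_retriable(status) or attempt == max_retries:
            return status, body
        time.sleep(full_jitter(base_delay_ms, max_delay_ms, attempt) / 1000.0)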
+ + Args: + registered_schema (RegisteredSchema): RegisteredSchema instance + """ + + subject = registered_schema.subject + schema_id = registered_schema.schema_id + version = registered_schema.version + with self.lock: + self.schema_id_index[subject][schema_id] = schema + self.schema_index[subject][schema] = schema_id + self.rs_id_index[subject][schema_id] = registered_schema + self.rs_version_index[subject][version] = registered_schema + self.rs_schema_index[subject][schema] = registered_schema + + def get_schema_by_id(self, subject: str, schema_id: int) -> Optional['Schema']: + """ + Get the schema instance associated with schema id from the cache. + + Args: + subject (str): The subject this schema is associated with + + schema_id (int): Id used to identify a schema + + Returns: + Schema: The schema if known; else None + """ + + with self.lock: + return self.schema_id_index.get(subject, {}).get(schema_id, None) + + def get_id_by_schema(self, subject: str, schema: 'Schema') -> Optional[int]: + """ + Get the schema id associated with schema instance from the cache. + + Args: + subject (str): The subject this schema is associated with + + schema (Schema): The schema + + Returns: + int: The schema id if known; else None + """ + + with self.lock: + return self.schema_index.get(subject, {}).get(schema, None) + + def get_registered_by_subject_schema(self, subject: str, schema: 'Schema') -> Optional['RegisteredSchema']: + """ + Get the schema associated with this schema registered under subject. + + Args: + subject (str): The subject this schema is associated with + + schema (Schema): The schema associated with this schema + + Returns: + RegisteredSchema: The registered schema if known; else None + """ + + with self.lock: + return self.rs_schema_index.get(subject, {}).get(schema, None) + + def get_registered_by_subject_id(self, subject: str, schema_id: int) -> Optional['RegisteredSchema']: + """ + Get the schema associated with this id registered under subject. + + Args: + subject (str): The subject this schema is associated with + + schema_id (int): The schema id associated with this schema + + Returns: + RegisteredSchema: The registered schema if known; else None + """ + + with self.lock: + return self.rs_id_index.get(subject, {}).get(schema_id, None) + + def get_registered_by_subject_version(self, subject: str, version: int) -> Optional['RegisteredSchema']: + """ + Get the schema associated with this version registered under subject. + + Args: + subject (str): The subject this schema is associated with + + version (int): The version associated with this schema + + Returns: + RegisteredSchema: The registered schema if known; else None + """ + + with self.lock: + return self.rs_version_index.get(subject, {}).get(version, None) + + def remove_by_subject(self, subject: str): + """ + Remove schemas with the given subject. + + Args: + subject (str): The subject + """ + + with self.lock: + if subject in self.schema_id_index: + del self.schema_id_index[subject] + if subject in self.schema_index: + del self.schema_index[subject] + if subject in self.rs_id_index: + del self.rs_id_index[subject] + if subject in self.rs_version_index: + del self.rs_version_index[subject] + if subject in self.rs_schema_index: + del self.rs_schema_index[subject] + + def remove_by_subject_version(self, subject: str, version: int): + """ + Remove schemas with the given subject. 
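The getters above support the usual check-cache-then-populate pattern. A brief sketch (illustrative; fetch_from_registry stands in for an HTTP call and is not part of this patch):

def get_schema_cached(cache: '_SchemaCache', subject: str, schema_id: int, fetch_from_registry):
    # Serve from the cache when possible; otherwise fetch and remember.
    schema = cache.get_schema_by_id(subject, schema_id)
    if schema is None:
        schema = fetch_from_registry(subject, schema_id)
        cache.set_schema(subject, schema_id, schema)
    return schema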
+ + Args: + subject (str): The subject + + version (int) The version + """ + + with self.lock: + if subject in self.rs_id_index: + for schema_id, registered_schema in self.rs_id_index[subject].items(): + if registered_schema.version == version: + del self.rs_schema_index[subject][schema_id] + if subject in self.rs_schema_index: + for schema, registered_schema in self.rs_schema_index[subject].items(): + if registered_schema.version == version: + del self.rs_schema_index[subject][schema] + rs = None + if subject in self.rs_version_index: + if version in self.rs_version_index[subject]: + rs = self.rs_version_index[subject][version] + del self.rs_version_index[subject][version] + if rs is not None: + if subject in self.schema_id_index: + if rs.schema_id in self.schema_id_index[subject]: + del self.schema_id_index[subject][rs.schema_id] + if rs.schema in self.schema_index[subject]: + del self.schema_index[subject][rs.schema] + + def clear(self): + """ + Clear the cache. + """ + + with self.lock: + self.schema_id_index.clear() + self.schema_index.clear() + self.rs_id_index.clear() + self.rs_version_index.clear() + self.rs_schema_index.clear() + +T = TypeVar("T") + + +class RuleKind(str, Enum): + CONDITION = "CONDITION" + TRANSFORM = "TRANSFORM" + + def __str__(self) -> str: + return str(self.value) + + +class RuleMode(str, Enum): + UPGRADE = "UPGRADE" + DOWNGRADE = "DOWNGRADE" + UPDOWN = "UPDOWN" + READ = "READ" + WRITE = "WRITE" + WRITEREAD = "WRITEREAD" + + def __str__(self) -> str: + return str(self.value) + + +@_attrs_define +class RuleParams: + params: Dict[str, str] = _attrs_field(factory=dict, hash=False) + + def to_dict(self) -> Dict[str, Any]: + field_dict: Dict[str, Any] = {} + field_dict.update(self.params) + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + + rule_params = cls(params=d) + + return rule_params + + def __hash__(self): + return hash(frozenset(self.params.items())) + + +@_attrs_define(frozen=True) +class Rule: + name: Optional[str] + doc: Optional[str] + kind: Optional[RuleKind] + mode: Optional[RuleMode] + type: Optional[str] + tags: Optional[List[str]] = _attrs_field(hash=False) + params: Optional[RuleParams] + expr: Optional[str] + on_success: Optional[str] + on_failure: Optional[str] + disabled: Optional[bool] + + def to_dict(self) -> Dict[str, Any]: + name = self.name + + doc = self.doc + + kind_str: Optional[str] = None + if self.kind is not None: + kind_str = self.kind.value + + mode_str: Optional[str] = None + if self.mode is not None: + mode_str = self.mode.value + + rule_type = self.type + + tags = self.tags + + _params: Optional[Dict[str, Any]] = None + if self.params is not None: + _params = self.params.to_dict() + + expr = self.expr + + on_success = self.on_success + + on_failure = self.on_failure + + disabled = self.disabled + + field_dict: Dict[str, Any] = {} + field_dict.update({}) + if name is not None: + field_dict["name"] = name + if doc is not None: + field_dict["doc"] = doc + if kind_str is not None: + field_dict["kind"] = kind_str + if mode_str is not None: + field_dict["mode"] = mode_str + if type is not None: + field_dict["type"] = rule_type + if tags is not None: + field_dict["tags"] = tags + if _params is not None: + field_dict["params"] = _params + if expr is not None: + field_dict["expr"] = expr + if on_success is not None: + field_dict["onSuccess"] = on_success + if on_failure is not None: + field_dict["onFailure"] = on_failure + if disabled is not None: + 
field_dict["disabled"] = disabled + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + name = d.pop("name", None) + + doc = d.pop("doc", None) + + _kind = d.pop("kind", None) + kind: Optional[RuleKind] = None + if _kind is not None: + kind = RuleKind(_kind) + + _mode = d.pop("mode", None) + mode: Optional[RuleMode] = None + if _mode is not None: + mode = RuleMode(_mode) + + rule_type = d.pop("type", None) + + tags = cast(List[str], d.pop("tags", None)) + + _params: Optional[Dict[str, Any]] = d.pop("params", None) + params: Optional[RuleParams] = None + if _params is not None: + params = RuleParams.from_dict(_params) + + expr = d.pop("expr", None) + + on_success = d.pop("onSuccess", None) + + on_failure = d.pop("onFailure", None) + + disabled = d.pop("disabled", None) + + rule = cls( + name=name, + doc=doc, + kind=kind, + mode=mode, + type=rule_type, + tags=tags, + params=params, + expr=expr, + on_success=on_success, + on_failure=on_failure, + disabled=disabled, + ) + + return rule + + +@_attrs_define +class RuleSet: + migration_rules: Optional[List["Rule"]] = _attrs_field(hash=False) + domain_rules: Optional[List["Rule"]] = _attrs_field(hash=False) + + def to_dict(self) -> Dict[str, Any]: + _migration_rules: Optional[List[Dict[str, Any]]] = None + if self.migration_rules is not None: + _migration_rules = [] + for migration_rules_item_data in self.migration_rules: + migration_rules_item = migration_rules_item_data.to_dict() + _migration_rules.append(migration_rules_item) + + _domain_rules: Optional[List[Dict[str, Any]]] = None + if self.domain_rules is not None: + _domain_rules = [] + for domain_rules_item_data in self.domain_rules: + domain_rules_item = domain_rules_item_data.to_dict() + _domain_rules.append(domain_rules_item) + + field_dict: Dict[str, Any] = {} + field_dict.update({}) + if _migration_rules is not None: + field_dict["migrationRules"] = _migration_rules + if _domain_rules is not None: + field_dict["domainRules"] = _domain_rules + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + migration_rules = [] + _migration_rules = d.pop("migrationRules", None) + for migration_rules_item_data in _migration_rules or []: + migration_rules_item = Rule.from_dict(migration_rules_item_data) + migration_rules.append(migration_rules_item) + + domain_rules = [] + _domain_rules = d.pop("domainRules", None) + for domain_rules_item_data in _domain_rules or []: + domain_rules_item = Rule.from_dict(domain_rules_item_data) + domain_rules.append(domain_rules_item) + + rule_set = cls( + migration_rules=migration_rules, + domain_rules=domain_rules, + ) + + return rule_set + + def __hash__(self): + return hash(frozenset((self.migration_rules or []) + (self.domain_rules or []))) + + +@_attrs_define +class MetadataTags: + tags: Dict[str, List[str]] = _attrs_field(factory=dict, hash=False) + + def to_dict(self) -> Dict[str, Any]: + field_dict: Dict[str, Any] = {} + for prop_name, prop in self.tags.items(): + field_dict[prop_name] = prop + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + + tags = {} + for prop_name, prop_dict in d.items(): + tag = cast(List[str], prop_dict) + + tags[prop_name] = tag + + metadata_tags = cls(tags=tags) + + return metadata_tags + + def __hash__(self): + return hash(frozenset(self.tags.items())) + + +@_attrs_define +class MetadataProperties: + properties: 
Dict[str, str] = _attrs_field(factory=dict, hash=False) + + def to_dict(self) -> Dict[str, Any]: + field_dict: Dict[str, Any] = {} + field_dict.update(self.properties) + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + + metadata_properties = cls(properties=d) + + return metadata_properties + + def __hash__(self): + return hash(frozenset(self.properties.items())) + + +@_attrs_define(frozen=True) +class Metadata: + tags: Optional[MetadataTags] + properties: Optional[MetadataProperties] + sensitive: Optional[List[str]] = _attrs_field(hash=False) + + def to_dict(self) -> Dict[str, Any]: + _tags: Optional[Dict[str, Any]] = None + if self.tags is not None: + _tags = self.tags.to_dict() + + _properties: Optional[Dict[str, Any]] = None + if self.properties is not None: + _properties = self.properties.to_dict() + + sensitive: Optional[List[str]] = None + if self.sensitive is not None: + sensitive = [] + for sensitive_item in self.sensitive: + sensitive.append(sensitive_item) + + field_dict: Dict[str, Any] = {} + if _tags is not None: + field_dict["tags"] = _tags + if _properties is not None: + field_dict["properties"] = _properties + if sensitive is not None: + field_dict["sensitive"] = sensitive + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + _tags: Optional[Dict[str, Any]] = d.pop("tags", None) + tags: Optional[MetadataTags] = None + if _tags is not None: + tags = MetadataTags.from_dict(_tags) + + _properties: Optional[Dict[str, Any]] = d.pop("properties", None) + properties: Optional[MetadataProperties] = None + if _properties is not None: + properties = MetadataProperties.from_dict(_properties) + + sensitive = [] + _sensitive = d.pop("sensitive", None) + for sensitive_item in _sensitive or []: + sensitive.append(sensitive_item) + + metadata = cls( + tags=tags, + properties=properties, + sensitive=sensitive, + ) + + return metadata + + +@_attrs_define(frozen=True) +class SchemaReference: + name: Optional[str] + subject: Optional[str] + version: Optional[int] + + def to_dict(self) -> Dict[str, Any]: + name = self.name + + subject = self.subject + + version = self.version + + field_dict: Dict[str, Any] = {} + if name is not None: + field_dict["name"] = name + if subject is not None: + field_dict["subject"] = subject + if version is not None: + field_dict["version"] = version + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + name = d.pop("name", None) + + subject = d.pop("subject", None) + + version = d.pop("version", None) + + schema_reference = cls( + name=name, + subject=subject, + version=version, + ) + + return schema_reference + + +class ConfigCompatibilityLevel(str, Enum): + BACKWARD = "BACKWARD" + BACKWARD_TRANSITIVE = "BACKWARD_TRANSITIVE" + FORWARD = "FORWARD" + FORWARD_TRANSITIVE = "FORWARD_TRANSITIVE" + FULL = "FULL" + FULL_TRANSITIVE = "FULL_TRANSITIVE" + NONE = "NONE" + + def __str__(self) -> str: + return str(self.value) + + +@_attrs_define +class ServerConfig: + compatibility: Optional[ConfigCompatibilityLevel] = None + compatibility_level: Optional[ConfigCompatibilityLevel] = None + compatibility_group: Optional[str] = None + default_metadata: Optional[Metadata] = None + override_metadata: Optional[Metadata] = None + default_rule_set: Optional[RuleSet] = None + override_rule_set: Optional[RuleSet] = None + + def to_dict(self) -> Dict[str, Any]: + 
_compatibility: Optional[str] = None + if self.compatibility is not None: + _compatibility = self.compatibility.value + + _compatibility_level: Optional[str] = None + if self.compatibility_level is not None: + _compatibility_level = self.compatibility_level.value + + compatibility_group = self.compatibility_group + + _default_metadata: Optional[Dict[str, Any]] + if isinstance(self.default_metadata, Metadata): + _default_metadata = self.default_metadata.to_dict() + else: + _default_metadata = self.default_metadata + + _override_metadata: Optional[Dict[str, Any]] + if isinstance(self.override_metadata, Metadata): + _override_metadata = self.override_metadata.to_dict() + else: + _override_metadata = self.override_metadata + + _default_rule_set: Optional[Dict[str, Any]] + if isinstance(self.default_rule_set, RuleSet): + _default_rule_set = self.default_rule_set.to_dict() + else: + _default_rule_set = self.default_rule_set + + _override_rule_set: Optional[Dict[str, Any]] + if isinstance(self.override_rule_set, RuleSet): + _override_rule_set = self.override_rule_set.to_dict() + else: + _override_rule_set = self.override_rule_set + + field_dict: Dict[str, Any] = {} + if _compatibility is not None: + field_dict["compatibility"] = _compatibility + if _compatibility_level is not None: + field_dict["compatibilityLevel"] = _compatibility_level + if compatibility_group is not None: + field_dict["compatibilityGroup"] = compatibility_group + if _default_metadata is not None: + field_dict["defaultMetadata"] = _default_metadata + if _override_metadata is not None: + field_dict["overrideMetadata"] = _override_metadata + if _default_rule_set is not None: + field_dict["defaultRuleSet"] = _default_rule_set + if _override_rule_set is not None: + field_dict["overrideRuleSet"] = _override_rule_set + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + _compatibility = d.pop("compatibility", None) + compatibility: Optional[ConfigCompatibilityLevel] + if _compatibility is None: + compatibility = None + else: + compatibility = ConfigCompatibilityLevel(_compatibility) + + _compatibility_level = d.pop("compatibilityLevel", None) + compatibility_level: Optional[ConfigCompatibilityLevel] + if _compatibility_level is None: + compatibility_level = None + else: + compatibility_level = ConfigCompatibilityLevel(_compatibility_level) + + compatibility_group = d.pop("compatibilityGroup", None) + + def _parse_default_metadata(data: object) -> Optional[Metadata]: + if data is None: + return data + if not isinstance(data, dict): + raise TypeError() + return Metadata.from_dict(data) + + default_metadata = _parse_default_metadata(d.pop("defaultMetadata", None)) + + def _parse_override_metadata(data: object) -> Optional[Metadata]: + if data is None: + return data + if not isinstance(data, dict): + raise TypeError() + return Metadata.from_dict(data) + + override_metadata = _parse_override_metadata(d.pop("overrideMetadata", None)) + + def _parse_default_rule_set(data: object) -> Optional[RuleSet]: + if data is None: + return data + if not isinstance(data, dict): + raise TypeError() + return RuleSet.from_dict(data) + + default_rule_set = _parse_default_rule_set(d.pop("defaultRuleSet", None)) + + def _parse_override_rule_set(data: object) -> Optional[RuleSet]: + if data is None: + return data + if not isinstance(data, dict): + raise TypeError() + return RuleSet.from_dict(data) + + override_rule_set = _parse_override_rule_set(d.pop("overrideRuleSet", None)) + + config 
= cls( + compatibility=compatibility, + compatibility_level=compatibility_level, + compatibility_group=compatibility_group, + default_metadata=default_metadata, + override_metadata=override_metadata, + default_rule_set=default_rule_set, + override_rule_set=override_rule_set, + ) + + return config + + +@_attrs_define(frozen=True, cache_hash=True) +class Schema: + """ + An unregistered schema. + """ + + schema_str: Optional[str] + schema_type: Optional[str] = "AVRO" + references: Optional[List[SchemaReference]] = _attrs_field(factory=list, hash=False) + metadata: Optional[Metadata] = None + rule_set: Optional[RuleSet] = None + + def to_dict(self) -> Dict[str, Any]: + schema = self.schema_str + + schema_type = self.schema_type + + _references: Optional[List[Dict[str, Any]]] = [] + if self.references is not None: + for references_item_data in self.references: + references_item = references_item_data.to_dict() + _references.append(references_item) + + _metadata: Optional[Dict[str, Any]] = None + if isinstance(self.metadata, Metadata): + _metadata = self.metadata.to_dict() + + _rule_set: Optional[Dict[str, Any]] = None + if isinstance(self.rule_set, RuleSet): + _rule_set = self.rule_set.to_dict() + + field_dict: Dict[str, Any] = {} + if schema is not None: + field_dict["schema"] = schema + if schema_type is not None: + field_dict["schemaType"] = schema_type + if _references is not None: + field_dict["references"] = _references + if _metadata is not None: + field_dict["metadata"] = _metadata + if _rule_set is not None: + field_dict["ruleSet"] = _rule_set + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + + schema = d.pop("schema", None) + + schema_type = d.pop("schemaType", "AVRO") + + references = [] + _references = d.pop("references", None) + for references_item_data in _references or []: + references_item = SchemaReference.from_dict(references_item_data) + + references.append(references_item) + + def _parse_metadata(data: object) -> Optional[Metadata]: + if data is None: + return data + if not isinstance(data, dict): + raise TypeError() + return Metadata.from_dict(data) + + metadata = _parse_metadata(d.pop("metadata", None)) + + def _parse_rule_set(data: object) -> Optional[RuleSet]: + if data is None: + return data + if not isinstance(data, dict): + raise TypeError() + return RuleSet.from_dict(data) + + rule_set = _parse_rule_set(d.pop("ruleSet", None)) + + schema = cls( + schema_str=schema, + schema_type=schema_type, + references=references, + metadata=metadata, + rule_set=rule_set, + ) + + return schema + + +@_attrs_define(frozen=True, cache_hash=True) +class RegisteredSchema: + """ + An registered schema. 
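An illustrative round trip through Schema.to_dict()/from_dict(), defined just above (the payload values are made up, but the field names follow those methods):

payload = {"schema": '{"type": "string"}', "schemaType": "AVRO"}

s = Schema.from_dict(payload)
assert s.schema_str == '{"type": "string"}'
assert s.schema_type == "AVRO"

# to_dict() reproduces the registry-style field names.
assert s.to_dict()["schemaType"] == "AVRO"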
+ """ + + schema_id: Optional[int] + schema: Optional[Schema] + subject: Optional[str] + version: Optional[int] + + def to_dict(self) -> Dict[str, Any]: + schema = self.schema + + schema_id = self.schema_id + + subject = self.subject + + version = self.version + + field_dict: Dict[str, Any] = {} + if schema is not None: + field_dict = schema.to_dict() + if schema_id is not None: + field_dict["id"] = schema_id + if subject is not None: + field_dict["subject"] = subject + if version is not None: + field_dict["version"] = version + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + + schema = Schema.from_dict(d) + + schema_id = d.pop("id", None) + + subject = d.pop("subject", None) + + version = d.pop("version", None) + + schema = cls( + schema_id=schema_id, + schema=schema, + subject=subject, + version=version, + ) + + return schema diff --git a/src/confluent_kafka/schema_registry/common/serde.py b/src/confluent_kafka/schema_registry/common/serde.py new file mode 100644 index 000000000..864adca0c --- /dev/null +++ b/src/confluent_kafka/schema_registry/common/serde.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2024 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import abc +import logging +from enum import Enum +from threading import Lock +from typing import Callable, List, Optional, Set, Dict, Any, TypeVar + +from confluent_kafka.schema_registry import RegisteredSchema +from confluent_kafka.schema_registry.schema_registry_client import RuleMode, \ + Rule, RuleKind, Schema +from confluent_kafka.schema_registry.wildcard_matcher import wildcard_match +from confluent_kafka.serialization import SerializationContext, SerializationError + + +log = logging.getLogger(__name__) + + +class FieldType(str, Enum): + RECORD = "RECORD" + ENUM = "ENUM" + ARRAY = "ARRAY" + MAP = "MAP" + COMBINED = "COMBINED" + FIXED = "FIXED" + STRING = "STRING" + BYTES = "BYTES" + INT = "INT" + LONG = "LONG" + FLOAT = "FLOAT" + DOUBLE = "DOUBLE" + BOOLEAN = "BOOLEAN" + NULL = "NULL" + + +class FieldContext(object): + __slots__ = ['containing_message', 'full_name', 'name', 'field_type', 'tags'] + + def __init__( + self, containing_message: Any, full_name: str, name: str, + field_type: FieldType, tags: Set[str] + ): + self.containing_message = containing_message + self.full_name = full_name + self.name = name + self.field_type = field_type + self.tags = tags + + def is_primitive(self) -> bool: + return self.field_type in (FieldType.INT, FieldType.LONG, FieldType.FLOAT, + FieldType.DOUBLE, FieldType.BOOLEAN, FieldType.NULL, + FieldType.STRING, FieldType.BYTES) + + def type_name(self) -> str: + return self.field_type.name + + +class RuleContext(object): + __slots__ = ['ser_ctx', 'source', 'target', 'subject', 'rule_mode', 'rule', + 'index', 'rules', 'inline_tags', 'field_transformer', '_field_contexts'] + + def __init__( + self, ser_ctx: SerializationContext, source: Optional[Schema], + target: Optional[Schema], subject: str, rule_mode: RuleMode, rule: Rule, + index: int, rules: List[Rule], inline_tags: Optional[Dict[str, Set[str]]], field_transformer + ): + self.ser_ctx = ser_ctx + self.source = source + self.target = target + self.subject = subject + self.rule_mode = rule_mode + self.rule = rule + self.index = index + self.rules = rules + self.inline_tags = inline_tags + self.field_transformer = field_transformer + self._field_contexts: List[FieldContext] = [] + + def get_parameter(self, name: str) -> Optional[str]: + params = self.rule.params + if params is not None: + value = params.params.get(name) + if value is not None: + return value + if (self.target is not None + and self.target.metadata is not None + and self.target.metadata.properties is not None): + value = self.target.metadata.properties.properties.get(name) + if value is not None: + return value + return None + + def _get_inline_tags(self, name: str) -> Set[str]: + if self.inline_tags is None: + return set() + return self.inline_tags.get(name, set()) + + def current_field(self) -> Optional[FieldContext]: + if not self._field_contexts: + return None + return self._field_contexts[-1] + + def enter_field( + self, containing_message: Any, full_name: str, name: str, + field_type: FieldType, tags: Optional[Set[str]] + ) -> FieldContext: + all_tags = set(tags if tags is not None else self._get_inline_tags(full_name)) + all_tags.update(self.get_tags(full_name)) + field_context = FieldContext(containing_message, full_name, name, field_type, all_tags) + self._field_contexts.append(field_context) + return field_context + + def get_tags(self, full_name: str) -> Set[str]: + result = set() + if (self.target is not None + and self.target.metadata is not None + and self.target.metadata.tags is not None): + tags = 
self.target.metadata.tags.tags + for k, v in tags.items(): + if wildcard_match(full_name, k): + result.update(v) + return result + + def exit_field(self): + if self._field_contexts: + self._field_contexts.pop() + + +FieldTransform = Callable[[RuleContext, FieldContext, Any], Any] + + +FieldTransformer = Callable[[RuleContext, FieldTransform, Any], Any] + + +class RuleBase(metaclass=abc.ABCMeta): + def configure(self, client_conf: dict, rule_conf: dict): + pass + + @abc.abstractmethod + def type(self) -> str: + raise NotImplementedError() + + def close(self): + pass + + +class RuleExecutor(RuleBase): + @abc.abstractmethod + def transform(self, ctx: RuleContext, message: Any) -> Any: + raise NotImplementedError() + + +class FieldRuleExecutor(RuleExecutor): + @abc.abstractmethod + def new_transform(self, ctx: RuleContext) -> FieldTransform: + raise NotImplementedError() + + def transform(self, ctx: RuleContext, message: Any) -> Any: + # TODO preserve source + if ctx.rule_mode in (RuleMode.WRITE, RuleMode.UPGRADE): + for i in range(ctx.index): + other_rule = ctx.rules[i] + if FieldRuleExecutor.are_transforms_with_same_tag(ctx.rule, other_rule): + # ignore this transform if an earlier one has the same tag + return message + elif ctx.rule_mode == RuleMode.READ or ctx.rule_mode == RuleMode.DOWNGRADE: + for i in range(ctx.index + 1, len(ctx.rules)): + other_rule = ctx.rules[i] + if FieldRuleExecutor.are_transforms_with_same_tag(ctx.rule, other_rule): + # ignore this transform if a later one has the same tag + return message + return ctx.field_transformer(ctx, self.new_transform(ctx), message) + + @staticmethod + def are_transforms_with_same_tag(rule1: Rule, rule2: Rule) -> bool: + return (bool(rule1.tags) + and rule1.kind == RuleKind.TRANSFORM + and rule1.kind == rule2.kind + and rule1.mode == rule2.mode + and rule1.type == rule2.type + and rule1.tags == rule2.tags) + + +class RuleAction(RuleBase): + @abc.abstractmethod + def run(self, ctx: RuleContext, message: Any, ex: Optional[Exception]): + raise NotImplementedError() + + +class ErrorAction(RuleAction): + def type(self) -> str: + return 'ERROR' + + def run(self, ctx: RuleContext, message: Any, ex: Optional[Exception]): + if ex is None: + raise SerializationError() + else: + raise SerializationError() from ex + + +class NoneAction(RuleAction): + def type(self) -> str: + return 'NONE' + + def run(self, ctx: RuleContext, message: Any, ex: Optional[Exception]): + pass + + +class RuleError(Exception): + pass + + +class RuleConditionError(RuleError): + def __init__(self, rule: Rule): + super().__init__(RuleConditionError.error_message(rule)) + + @staticmethod + def error_message(rule: Rule) -> str: + if rule.doc: + return rule.doc + elif rule.expr: + return f"Rule expr failed: {rule.expr}" + else: + return f"Rule failed: {rule.name}" + + +class Migration(object): + __slots__ = ['rule_mode', 'source', 'target'] + + def __init__( + self, rule_mode: RuleMode, source: Optional[RegisteredSchema], + target: Optional[RegisteredSchema] + ): + self.rule_mode = rule_mode + self.source = source + self.target = target + +T = TypeVar("T") + +class ParsedSchemaCache(object): + """ + Thread-safe cache for parsed schemas + """ + + def __init__(self): + self.lock = Lock() + self.parsed_schemas = {} + + def set(self, schema: Schema, parsed_schema: T): + """ + Add a Schema identified by schema_id to the cache. 
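A minimal usage sketch for ParsedSchemaCache (illustrative; json.loads stands in here for whatever format-specific parsing a serializer performs):

import json

parsed_cache = ParsedSchemaCache()
schema = Schema('{"type": "string"}', schema_type="JSON")

# Parse once, then reuse the parsed form on subsequent lookups.
parsed = parsed_cache.get_parsed_schema(schema)
if parsed is None:
    parsed = json.loads(schema.schema_str)
    parsed_cache.set(schema, parsed)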
+ + Args: + schema (Schema): The schema + + parsed_schema (Any): The parsed schema + """ + + with self.lock: + self.parsed_schemas[schema] = parsed_schema + + def get_parsed_schema(self, schema: Schema) -> Optional[T]: + """ + Get the parsed schema associated with the schema + + Args: + schema (Schema): The schema + + Returns: + The parsed schema if known; else None + """ + + with self.lock: + return self.parsed_schemas.get(schema, None) + + def clear(self): + """ + Clear the cache. + """ + + with self.lock: + self.parsed_schemas.clear() diff --git a/src/confluent_kafka/schema_registry/json_schema.py b/src/confluent_kafka/schema_registry/json_schema.py index 157d0dd7f..cff914d96 100644 --- a/src/confluent_kafka/schema_registry/json_schema.py +++ b/src/confluent_kafka/schema_registry/json_schema.py @@ -15,792 +15,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -import decimal -from io import BytesIO - -import json -import struct -from typing import Union, Optional, List, Set, Tuple, Callable - -import httpx -import referencing -from cachetools import LRUCache -from jsonschema import validate, ValidationError -from jsonschema.protocols import Validator -from jsonschema.validators import validator_for -from referencing import Registry, Resource -from referencing._core import Resolver - -from confluent_kafka.schema_registry import (_MAGIC_BYTE, - Schema, - topic_subject_name_strategy, - RuleKind, - RuleMode, SchemaRegistryClient) -from confluent_kafka.schema_registry.rule_registry import RuleRegistry -from confluent_kafka.schema_registry.serde import BaseSerializer, \ - BaseDeserializer, RuleContext, FieldTransform, FieldType, \ - RuleConditionError, ParsedSchemaCache -from confluent_kafka.serialization import (SerializationError, - SerializationContext) - - -JsonMessage = Union[ - None, # 'null' Avro type - str, # 'string' and 'enum' - float, # 'float' and 'double' - int, # 'int' and 'long' - decimal.Decimal, # 'fixed' - bool, # 'boolean' - list, # 'array' - dict, # 'map' and 'record' -] - -JsonSchema = Union[bool, dict] - -DEFAULT_SPEC = referencing.jsonschema.DRAFT7 - - -class _ContextStringIO(BytesIO): - """ - Wrapper to allow use of StringIO via 'with' constructs. - """ - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() - return False - - -def _retrieve_via_httpx(uri: str): - response = httpx.get(uri) - return Resource.from_contents( - response.json(), default_specification=DEFAULT_SPEC) - - -def _resolve_named_schema( - schema: Schema, schema_registry_client: SchemaRegistryClient, - ref_registry: Optional[Registry] = None -) -> Registry: - """ - Resolves named schemas referenced by the provided schema recursively. - :param schema: Schema to resolve named schemas for. - :param schema_registry_client: SchemaRegistryClient to use for retrieval. - :param ref_registry: Registry of named schemas resolved recursively. 
- :return: Registry - """ - if ref_registry is None: - # Retrieve external schemas for backward compatibility - ref_registry = Registry(retrieve=_retrieve_via_httpx) - if schema.references is not None: - for ref in schema.references: - referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) - ref_registry = _resolve_named_schema(referenced_schema.schema, schema_registry_client, ref_registry) - referenced_schema_dict = json.loads(referenced_schema.schema.schema_str) - resource = Resource.from_contents( - referenced_schema_dict, default_specification=DEFAULT_SPEC) - ref_registry = ref_registry.with_resource(ref.name, resource) - return ref_registry - - -class JSONSerializer(BaseSerializer): - """ - Serializer that outputs JSON encoded data with Confluent Schema Registry framing. - - Configuration properties: - - +-----------------------------+----------+----------------------------------------------------+ - | Property Name | Type | Description | - +=============================+==========+====================================================+ - | | | If True, automatically register the configured | - | ``auto.register.schemas`` | bool | schema with Confluent Schema Registry if it has | - | | | not previously been associated with the relevant | - | | | subject (determined via subject.name.strategy). | - | | | | - | | | Defaults to True. | - | | | | - | | | Raises SchemaRegistryError if the schema was not | - | | | registered against the subject, or could not be | - | | | successfully registered. | - +-----------------------------+----------+----------------------------------------------------+ - | | | Whether to normalize schemas, which will | - | ``normalize.schemas`` | bool | transform schemas to have a consistent format, | - | | | including ordering properties and references. | - +-----------------------------+----------+----------------------------------------------------+ - | | | Whether to use the given schema ID for | - | ``use.schema.id`` | int | serialization. | - | | | | - +-----------------------------+----------+--------------------------------------------------+ - | | | Whether to use the latest subject version for | - | ``use.latest.version`` | bool | serialization. | - | | | | - | | | WARNING: There is no check that the latest | - | | | schema is backwards compatible with the object | - | | | being serialized. | - | | | | - | | | Defaults to False. | - +-----------------------------+----------+----------------------------------------------------+ - | | | Whether to use the latest subject version with | - | ``use.latest.with.metadata``| dict | the given metadata. | - | | | | - | | | WARNING: There is no check that the latest | - | | | schema is backwards compatible with the object | - | | | being serialized. | - | | | | - | | | Defaults to None. | - +-----------------------------+----------+----------------------------------------------------+ - | | | Callable(SerializationContext, str) -> str | - | | | | - | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | - | | | constructed. Standard naming strategies are | - | | | defined in the confluent_kafka.schema_registry | - | | | namespace. | - | | | | - | | | Defaults to topic_subject_name_strategy. | - +-----------------------------+----------+----------------------------------------------------+ - | | | Whether to validate the payload against the | - | ``validate`` | bool | the given schema. 
| - | | | | - +-----------------------------+----------+----------------------------------------------------+ - - Schemas are registered against subject names in Confluent Schema Registry that - define a scope in which the schemas can be evolved. By default, the subject name - is formed by concatenating the topic name with the message field (key or value) - separated by a hyphen. - - i.e. {topic name}-{message field} - - Alternative naming strategies may be configured with the property - ``subject.name.strategy``. - - Supported subject name strategies: - - +--------------------------------------+------------------------------+ - | Subject Name Strategy | Output Format | - +======================================+==============================+ - | topic_subject_name_strategy(default) | {topic name}-{message field} | - +--------------------------------------+------------------------------+ - | topic_record_subject_name_strategy | {topic name}-{record name} | - +--------------------------------------+------------------------------+ - | record_subject_name_strategy | {record name} | - +--------------------------------------+------------------------------+ - - See `Subject name strategy `_ for additional details. - - Notes: - The ``title`` annotation, referred to elsewhere as a record name - is not strictly required by the JSON Schema specification. It is - however required by this serializer in order to register the schema - with Confluent Schema Registry. - - Prior to serialization, all objects must first be converted to - a dict instance. This may be handled manually prior to calling - :py:func:`Producer.produce()` or by registering a `to_dict` - callable with JSONSerializer. - - Args: - schema_str (str, Schema): - `JSON Schema definition. `_ - Accepts schema as either a string or a :py:class:`Schema` instance. - Note that string definitions cannot reference other schemas. For - referencing other schemas, use a :py:class:`Schema` instance. - - schema_registry_client (SchemaRegistryClient): Schema Registry - client instance. - - to_dict (callable, optional): Callable(object, SerializationContext) -> dict. - Converts object to a dict. - - conf (dict): JsonSerializer configuration. 
- """ # noqa: E501 - __slots__ = ['_known_subjects', '_parsed_schema', '_ref_registry', - '_schema', '_schema_id', '_schema_name', '_to_dict', - '_parsed_schemas', '_validators', '_validate', '_json_encode'] - - _default_conf = {'auto.register.schemas': True, - 'normalize.schemas': False, - 'use.schema.id': None, - 'use.latest.version': False, - 'use.latest.with.metadata': None, - 'subject.name.strategy': topic_subject_name_strategy, - 'validate': True} - - def __init__( - self, - schema_str: Union[str, Schema, None], - schema_registry_client: SchemaRegistryClient, - to_dict: Optional[Callable[[object, SerializationContext], dict]] = None, - conf: Optional[dict] = None, - rule_conf: Optional[dict] = None, - rule_registry: Optional[RuleRegistry] = None, - json_encode: Optional[Callable] = None, - ): - super().__init__() - if isinstance(schema_str, str): - self._schema = Schema(schema_str, schema_type="JSON") - elif isinstance(schema_str, Schema): - self._schema = schema_str - else: - self._schema = None - - self._json_encode = json_encode or json.dumps - self._registry = schema_registry_client - self._rule_registry = ( - rule_registry if rule_registry else RuleRegistry.get_global_instance() - ) - self._schema_id = None - self._known_subjects = set() - self._parsed_schemas = ParsedSchemaCache() - self._validators = LRUCache(1000) - - if to_dict is not None and not callable(to_dict): - raise ValueError("to_dict must be callable with the signature " - "to_dict(object, SerializationContext)->dict") - - self._to_dict = to_dict - - conf_copy = self._default_conf.copy() - if conf is not None: - conf_copy.update(conf) - - self._auto_register = conf_copy.pop('auto.register.schemas') - if not isinstance(self._auto_register, bool): - raise ValueError("auto.register.schemas must be a boolean value") - - self._normalize_schemas = conf_copy.pop('normalize.schemas') - if not isinstance(self._normalize_schemas, bool): - raise ValueError("normalize.schemas must be a boolean value") - - self._use_schema_id = conf_copy.pop('use.schema.id') - if (self._use_schema_id is not None and - not isinstance(self._use_schema_id, int)): - raise ValueError("use.schema.id must be an int value") - - self._use_latest_version = conf_copy.pop('use.latest.version') - if not isinstance(self._use_latest_version, bool): - raise ValueError("use.latest.version must be a boolean value") - if self._use_latest_version and self._auto_register: - raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') - if (self._use_latest_with_metadata is not None and - not isinstance(self._use_latest_with_metadata, dict)): - raise ValueError("use.latest.with.metadata must be a dict value") - - self._subject_name_func = conf_copy.pop('subject.name.strategy') - if not callable(self._subject_name_func): - raise ValueError("subject.name.strategy must be callable") - - self._validate = conf_copy.pop('validate') - if not isinstance(self._normalize_schemas, bool): - raise ValueError("validate must be a boolean value") - - if len(conf_copy) > 0: - raise ValueError("Unrecognized properties: {}" - .format(", ".join(conf_copy.keys()))) - - schema_dict, ref_registry = self._get_parsed_schema(self._schema) - if schema_dict: - schema_name = schema_dict.get('title', None) - else: - schema_name = None - - self._schema_name = schema_name - self._parsed_schema = schema_dict - self._ref_registry = ref_registry - - for rule in self._rule_registry.get_executors(): - 
rule.configure(self._registry.config() if self._registry else {}, - rule_conf if rule_conf else {}) - - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: - """ - Serializes an object to JSON, prepending it with Confluent Schema Registry - framing. - - Args: - obj (object): The object instance to serialize. - - ctx (SerializationContext): Metadata relevant to the serialization - operation. - - Raises: - SerializerError if any error occurs serializing obj. - - Returns: - bytes: None if obj is None, else a byte array containing the JSON - serialized data with Confluent Schema Registry framing. - """ - - if obj is None: - return None - - subject = self._subject_name_func(ctx, self._schema_name) - latest_schema = self._get_reader_schema(subject) - if latest_schema is not None: - self._schema_id = latest_schema.schema_id - elif subject not in self._known_subjects: - # Check to ensure this schema has been registered under subject_name. - if self._auto_register: - # The schema name will always be the same. We can't however register - # a schema without a subject so we set the schema_id here to handle - # the initial registration. - self._schema_id = self._registry.register_schema(subject, - self._schema, - self._normalize_schemas) - else: - registered_schema = self._registry.lookup_schema(subject, - self._schema, - self._normalize_schemas) - self._schema_id = registered_schema.schema_id - - self._known_subjects.add(subject) - - if self._to_dict is not None: - value = self._to_dict(obj, ctx) - else: - value = obj - - if latest_schema is not None: - schema = latest_schema.schema - parsed_schema, ref_registry = self._get_parsed_schema(latest_schema.schema) - root_resource = Resource.from_contents( - parsed_schema, default_specification=DEFAULT_SPEC) - ref_resolver = ref_registry.resolver_with_root(root_resource) - field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 - transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) - value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, - latest_schema.schema, value, None, - field_transformer) - else: - schema = self._schema - parsed_schema, ref_registry = self._parsed_schema, self._ref_registry - - if self._validate: - try: - validator = self._get_validator(schema, parsed_schema, ref_registry) - validator.validate(value) - except ValidationError as ve: - raise SerializationError(ve.message) - - with _ContextStringIO() as fo: - # Write the magic byte and schema ID in network byte order (big endian) - fo.write(struct.pack(">bI", _MAGIC_BYTE, self._schema_id)) - # JSON dump always writes a str never bytes - # https://docs.python.org/3/library/json.html - encoded_value = self._json_encode(value) - if isinstance(encoded_value, str): - encoded_value = encoded_value.encode("utf8") - fo.write(encoded_value) - - return fo.getvalue() - - def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]: - if schema is None: - return None, None - - result = self._parsed_schemas.get_parsed_schema(schema) - if result is not None: - return result - - ref_registry = _resolve_named_schema(schema, self._registry) - parsed_schema = json.loads(schema.schema_str) - - self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) - return parsed_schema, ref_registry - - def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Registry) -> Validator: - validator = self._validators.get(schema, None) - if validator is 
not None: - return validator - - cls = validator_for(parsed_schema) - cls.check_schema(parsed_schema) - validator = cls(parsed_schema, registry=registry) - - self._validators[schema] = validator - return validator - - -class JSONDeserializer(BaseDeserializer): - """ - Deserializer for JSON encoded data with Confluent Schema Registry - framing. - - Configuration properties: - - +-----------------------------+----------+----------------------------------------------------+ - | Property Name | Type | Description | - +=============================+==========+====================================================+ - +-----------------------------+----------+----------------------------------------------------+ - | | | Whether to use the latest subject version for | - | ``use.latest.version`` | bool | deserialization. | - | | | | - | | | Defaults to False. | - +-----------------------------+----------+----------------------------------------------------+ - | | | Whether to use the latest subject version with | - | ``use.latest.with.metadata``| dict | the given metadata. | - | | | | - | | | Defaults to None. | - +-----------------------------+----------+----------------------------------------------------+ - | | | Callable(SerializationContext, str) -> str | - | | | | - | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | - | | | constructed. Standard naming strategies are | - | | | defined in the confluent_kafka.schema_registry | - | | | namespace. | - | | | | - | | | Defaults to topic_subject_name_strategy. | - +-----------------------------+----------+----------------------------------------------------+ - | | | Whether to validate the payload against the | - | ``validate`` | bool | the given schema. | - | | | | - +-----------------------------+----------+----------------------------------------------------+ - - Args: - schema_str (str, Schema, optional): - `JSON schema definition `_ - Accepts schema as either a string or a :py:class:`Schema` instance. - Note that string definitions cannot reference other schemas. For referencing other schemas, - use a :py:class:`Schema` instance. If not provided, schemas will be - retrieved from schema_registry_client based on the schema ID in the - wire header of each message. - - from_dict (callable, optional): Callable(dict, SerializationContext) -> object. - Converts a dict to a Python object instance. - - schema_registry_client (SchemaRegistryClient, optional): Schema Registry client instance. Needed if ``schema_str`` is a schema referencing other schemas or is not provided. 
- """ # noqa: E501 - - __slots__ = ['_reader_schema', '_ref_registry', '_from_dict', '_schema', - '_parsed_schemas', '_validators', '_validate', '_json_decode'] - - _default_conf = {'use.latest.version': False, - 'use.latest.with.metadata': None, - 'subject.name.strategy': topic_subject_name_strategy, - 'validate': True} - - def __init__( - self, - schema_str: Union[str, Schema, None], - from_dict: Optional[Callable[[dict, SerializationContext], object]] = None, - schema_registry_client: Optional[SchemaRegistryClient] = None, - conf: Optional[dict] = None, - rule_conf: Optional[dict] = None, - rule_registry: Optional[RuleRegistry] = None, - json_decode: Optional[Callable] = None, - ): - super().__init__() - if isinstance(schema_str, str): - schema = Schema(schema_str, schema_type="JSON") - elif isinstance(schema_str, Schema): - schema = schema_str - if bool(schema.references) and schema_registry_client is None: - raise ValueError( - """schema_registry_client must be provided if "schema_str" is a Schema instance with references""") - elif schema_str is None: - if schema_registry_client is None: - raise ValueError( - """schema_registry_client must be provided if "schema_str" is not provided""" - ) - schema = schema_str - else: - raise TypeError('You must pass either str or Schema') - - self._schema = schema - self._registry = schema_registry_client - self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._parsed_schemas = ParsedSchemaCache() - self._validators = LRUCache(1000) - self._json_decode = json_decode or json.loads - self._use_schema_id = None - - conf_copy = self._default_conf.copy() - if conf is not None: - conf_copy.update(conf) - - self._use_latest_version = conf_copy.pop('use.latest.version') - if not isinstance(self._use_latest_version, bool): - raise ValueError("use.latest.version must be a boolean value") - - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') - if (self._use_latest_with_metadata is not None and - not isinstance(self._use_latest_with_metadata, dict)): - raise ValueError("use.latest.with.metadata must be a dict value") - - self._subject_name_func = conf_copy.pop('subject.name.strategy') - if not callable(self._subject_name_func): - raise ValueError("subject.name.strategy must be callable") - - self._validate = conf_copy.pop('validate') - if not isinstance(self._validate, bool): - raise ValueError("validate must be a boolean value") - - if len(conf_copy) > 0: - raise ValueError("Unrecognized properties: {}" - .format(", ".join(conf_copy.keys()))) - - if schema: - self._reader_schema, self._ref_registry = self._get_parsed_schema(self._schema) - else: - self._reader_schema, self._ref_registry = None, None - - if from_dict is not None and not callable(from_dict): - raise ValueError("from_dict must be callable with the signature" - " from_dict(dict, SerializationContext) -> object") - - self._from_dict = from_dict - - for rule in self._rule_registry.get_executors(): - rule.configure(self._registry.config() if self._registry else {}, - rule_conf if rule_conf else {}) - - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: - """ - Deserialize a JSON encoded record with Confluent Schema Registry framing to - a dict, or object instance according to from_dict if from_dict is specified. - - Args: - data (bytes): A JSON serialized record with Confluent Schema Registry framing. 
- - ctx (SerializationContext): Metadata relevant to the serialization operation. - - Returns: - A dict, or object instance according to from_dict if from_dict is specified. - - Raises: - SerializerError: If there was an error reading the Confluent framing data, or - if ``data`` was not successfully validated with the configured schema. - """ - - if data is None: - return None - - if len(data) <= 5: - raise SerializationError("Expecting data framing of length 6 bytes or " - "more but total data size is {} bytes. This " - "message was not produced with a Confluent " - "Schema Registry serializer".format(len(data))) - - subject = self._subject_name_func(ctx, None) - latest_schema = None - if subject is not None and self._registry is not None: - latest_schema = self._get_reader_schema(subject) - - with _ContextStringIO(data) as payload: - magic, schema_id = struct.unpack('>bI', payload.read(5)) - if magic != _MAGIC_BYTE: - raise SerializationError("Unexpected magic byte {}. This message " - "was not produced with a Confluent " - "Schema Registry serializer".format(magic)) - - # JSON documents are self-describing; no need to query schema - obj_dict = self._json_decode(payload.read()) - - if self._registry is not None: - writer_schema_raw = self._registry.get_schema(schema_id) - writer_schema, writer_ref_registry = self._get_parsed_schema(writer_schema_raw) - if subject is None: - subject = self._subject_name_func(ctx, writer_schema.get("title")) - if subject is not None: - latest_schema = self._get_reader_schema(subject) - else: - writer_schema_raw = None - writer_schema, writer_ref_registry = None, None - - if latest_schema is not None: - migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) - reader_schema_raw = latest_schema.schema - reader_schema, reader_ref_registry = self._get_parsed_schema(latest_schema.schema) - elif self._schema is not None: - migrations = None - reader_schema_raw = self._schema - reader_schema, reader_ref_registry = self._reader_schema, self._ref_registry - else: - migrations = None - reader_schema_raw = writer_schema_raw - reader_schema, reader_ref_registry = writer_schema, writer_ref_registry - - if migrations: - obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) - - reader_root_resource = Resource.from_contents( - reader_schema, default_specification=DEFAULT_SPEC) - reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) - field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 - transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, - "$", message, field_transform)) - obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, obj_dict, None, - field_transformer) - - if self._validate: - try: - validator = self._get_validator(reader_schema_raw, reader_schema, reader_ref_registry) - validator.validate(obj_dict) - except ValidationError as ve: - raise SerializationError(ve.message) - - if self._from_dict is not None: - return self._from_dict(obj_dict, ctx) - - return obj_dict - - def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]: - if schema is None: - return None, None - - result = self._parsed_schemas.get_parsed_schema(schema) - if result is not None: - return result - - ref_registry = _resolve_named_schema(schema, self._registry) - parsed_schema = json.loads(schema.schema_str) - - self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) - return parsed_schema, 
ref_registry - - def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Registry) -> Validator: - validator = self._validators.get(schema, None) - if validator is not None: - return validator - - cls = validator_for(parsed_schema) - cls.check_schema(parsed_schema) - validator = cls(parsed_schema, registry=registry) - - self._validators[schema] = validator - return validator - - -def transform( - ctx: RuleContext, schema: JsonSchema, ref_registry: Registry, ref_resolver: Resolver, - path: str, message: JsonMessage, field_transform: FieldTransform -) -> Optional[JsonMessage]: - if message is None or schema is None or isinstance(schema, bool): - return message - field_ctx = ctx.current_field() - if field_ctx is not None: - field_ctx.field_type = get_type(schema) - all_of = schema.get("allOf") - if all_of is not None: - subschema = _validate_subschemas(all_of, message, ref_registry) - if subschema is not None: - return transform(ctx, subschema, ref_registry, ref_resolver, path, message, field_transform) - any_of = schema.get("anyOf") - if any_of is not None: - subschema = _validate_subschemas(any_of, message, ref_registry) - if subschema is not None: - return transform(ctx, subschema, ref_registry, ref_resolver, path, message, field_transform) - one_of = schema.get("oneOf") - if one_of is not None: - subschema = _validate_subschemas(one_of, message, ref_registry) - if subschema is not None: - return transform(ctx, subschema, ref_registry, ref_resolver, path, message, field_transform) - items = schema.get("items") - if items is not None: - if isinstance(message, list): - return [transform(ctx, items, ref_registry, ref_resolver, path, item, field_transform) for item in message] - ref = schema.get("$ref") - if ref is not None: - ref_schema = ref_resolver.lookup(ref) - return transform(ctx, ref_schema.contents, ref_registry, ref_resolver, path, message, field_transform) - schema_type = get_type(schema) - if schema_type == FieldType.RECORD: - props = schema.get("properties") - if props is not None: - for prop_name, prop_schema in props.items(): - _transform_field(ctx, path, prop_name, message, - prop_schema, ref_registry, ref_resolver, field_transform) - return message - if schema_type in (FieldType.ENUM, FieldType.STRING, FieldType.INT, FieldType.DOUBLE, FieldType.BOOLEAN): - if field_ctx is not None: - rule_tags = ctx.rule.tags - if not rule_tags or not _disjoint(set(rule_tags), field_ctx.tags): - return field_transform(ctx, field_ctx, message) - return message - - -def _transform_field( - ctx: RuleContext, path: str, prop_name: str, message: JsonMessage, - prop_schema: JsonSchema, ref_registry: Registry, ref_resolver: Resolver, field_transform: FieldTransform -): - full_name = path + "." 
+ prop_name - try: - ctx.enter_field( - message, - full_name, - prop_name, - get_type(prop_schema), - get_inline_tags(prop_schema) - ) - value = message[prop_name] - new_value = transform(ctx, prop_schema, ref_registry, ref_resolver, full_name, value, field_transform) - if ctx.rule.kind == RuleKind.CONDITION: - if new_value is False: - raise RuleConditionError(ctx.rule) - else: - message[prop_name] = new_value - finally: - ctx.exit_field() - - -def _validate_subschemas( - subschemas: List[JsonSchema], - message: JsonMessage, - registry: Registry -) -> Optional[JsonSchema]: - for subschema in subschemas: - try: - validate(instance=message, schema=subschema, registry=registry) - return subschema - except ValidationError: - pass - return None - - -def get_type(schema: JsonSchema) -> FieldType: - if isinstance(schema, list): - return FieldType.COMBINED - elif isinstance(schema, dict): - schema_type = schema.get("type") - else: - # string schemas; this could be either a named schema or a primitive type - schema_type = schema - - if schema.get("const") is not None or schema.get("enum") is not None: - return FieldType.ENUM - if schema_type == "object": - props = schema.get("properties") - if not props: - return FieldType.MAP - return FieldType.RECORD - if schema_type == "array": - return FieldType.ARRAY - if schema_type == "string": - return FieldType.STRING - if schema_type == "integer": - return FieldType.INT - if schema_type == "number": - return FieldType.DOUBLE - if schema_type == "boolean": - return FieldType.BOOLEAN - if schema_type == "null": - return FieldType.NULL - return FieldType.NULL - - -def _disjoint(tags1: Set[str], tags2: Set[str]) -> bool: - for tag in tags1: - if tag in tags2: - return False - return True - - -def get_inline_tags(schema: JsonSchema) -> Set[str]: - tags = schema.get("confluent:tags") - if tags is None: - return set() - else: - return set(tags) +from .common.json_schema import * +from ._sync.json_schema import * diff --git a/src/confluent_kafka/schema_registry/protobuf.py b/src/confluent_kafka/schema_registry/protobuf.py index 39c2ce108..135afca16 100644 --- a/src/confluent_kafka/schema_registry/protobuf.py +++ b/src/confluent_kafka/schema_registry/protobuf.py @@ -15,1138 +15,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -import io -import sys -import base64 -import struct -import warnings -from collections import deque -from decimal import Context, Decimal, MAX_PREC -from typing import Set, List, Union, Optional, Any, Tuple - -from google.protobuf import descriptor_pb2, any_pb2, api_pb2, empty_pb2, \ - duration_pb2, field_mask_pb2, source_context_pb2, struct_pb2, timestamp_pb2, \ - type_pb2, wrappers_pb2 -from google.protobuf import json_format -from google.protobuf.descriptor_pool import DescriptorPool -from google.type import calendar_period_pb2, color_pb2, date_pb2, datetime_pb2, \ - dayofweek_pb2, expr_pb2, fraction_pb2, latlng_pb2, money_pb2, month_pb2, \ - postal_address_pb2, timeofday_pb2, quaternion_pb2 - -import confluent_kafka.schema_registry.confluent.meta_pb2 as meta_pb2 - -from google.protobuf.descriptor import Descriptor, FieldDescriptor, \ - FileDescriptor -from google.protobuf.message import DecodeError, Message -from google.protobuf.message_factory import GetMessageClass - -from . 
import (_MAGIC_BYTE, - reference_subject_name_strategy, - topic_subject_name_strategy, SchemaRegistryClient) -from .confluent.types import decimal_pb2 -from .rule_registry import RuleRegistry -from .schema_registry_client import (Schema, - SchemaReference, - RuleKind, - RuleMode) -from confluent_kafka.serialization import SerializationError, \ - SerializationContext -from .serde import BaseSerializer, BaseDeserializer, RuleContext, \ - FieldTransform, FieldType, RuleConditionError, ParsedSchemaCache - -# Convert an int to bytes (inverse of ord()) -# Python3.chr() -> Unicode -# Python2.chr() -> str(alias for bytes) -if sys.version > '3': - def _bytes(v: int) -> bytes: - """ - Convert int to bytes - - Args: - v (int): The int to convert to bytes. - """ - return bytes((v,)) -else: - def _bytes(v: int) -> str: - """ - Convert int to bytes - - Args: - v (int): The int to convert to bytes. - """ - return chr(v) - - -class _ContextStringIO(io.BytesIO): - """ - Wrapper to allow use of StringIO via 'with' constructs. - """ - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() - return False - - -def _create_index_array(msg_desc: Descriptor) -> List[int]: - """ - Creates an index array specifying the location of msg_desc in - the referenced FileDescriptor. - - Args: - msg_desc (MessageDescriptor): Protobuf MessageDescriptor - - Returns: - list of int: Protobuf MessageDescriptor index array. - - Raises: - ValueError: If the message descriptor is malformed. - """ - - msg_idx = deque() - - # Walk the nested MessageDescriptor tree up to the root. - current = msg_desc - found = False - while current.containing_type is not None: - previous = current - current = previous.containing_type - # find child's position - for idx, node in enumerate(current.nested_types): - if node == previous: - msg_idx.appendleft(idx) - found = True - break - if not found: - raise ValueError("Nested MessageDescriptor not found") - - # Add the index of the root MessageDescriptor in the FileDescriptor. - found = False - for idx, msg_type_name in enumerate(msg_desc.file.message_types_by_name): - if msg_type_name == current.name: - msg_idx.appendleft(idx) - found = True - break - if not found: - raise ValueError("MessageDescriptor not found in file") - - return list(msg_idx) - - -def _schema_to_str(file_descriptor: FileDescriptor) -> str: - """ - Base64 encode a FileDescriptor - - Args: - file_descriptor (FileDescriptor): FileDescriptor to encode. - - Returns: - str: Base64 encoded FileDescriptor - """ - - return base64.standard_b64encode(file_descriptor.serialized_pb).decode('ascii') - - -def _proto_to_str(file_descriptor_proto: descriptor_pb2.FileDescriptorProto) -> str: - """ - Base64 encode a FileDescriptorProto - - Args: - file_descriptor_proto (FileDescriptorProto): FileDescriptorProto to encode. - - Returns: - str: Base64 encoded FileDescriptorProto - """ - - return base64.standard_b64encode(file_descriptor_proto.SerializeToString()).decode('ascii') - - -def _str_to_proto(name: str, schema_str: str) -> descriptor_pb2.FileDescriptorProto: - """ - Base64 decode a FileDescriptor - - Args: - schema_str (str): Base64 encoded FileDescriptorProto - - Returns: - FileDescriptorProto: schema. 
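For intuition on the index array built by _create_index_array above: it records the path of nested-type indexes from the file root down to the target message, which the deserializer later walks back through message_type/nested_type. A small sketch with a hypothetical file layout:

from google.protobuf import descriptor_pb2

fdp = descriptor_pb2.FileDescriptorProto(name='orders.proto')  # hypothetical file
order = fdp.message_type.add(name='Order')                     # index array [0]
order.nested_type.add(name='LineItem')                         # index array [0, 0]
fdp.message_type.add(name='Customer')                          # index array [1]

assert fdp.message_type[0].nested_type[0].name == 'LineItem'
assert fdp.message_type[1].name == 'Customer'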
- """ - - serialized_pb = base64.standard_b64decode(schema_str.encode('ascii')) - file_descriptor_proto = descriptor_pb2.FileDescriptorProto() - try: - file_descriptor_proto.ParseFromString(serialized_pb) - file_descriptor_proto.name = name - except DecodeError as e: - raise SerializationError(str(e)) - return file_descriptor_proto - - -def _resolve_named_schema( - schema: Schema, - schema_registry_client: SchemaRegistryClient, - pool: DescriptorPool, - visited: Optional[Set[str]] = None -): - """ - Resolves named schemas referenced by the provided schema recursively. - :param schema: Schema to resolve named schemas for. - :param schema_registry_client: SchemaRegistryClient to use for retrieval. - :param pool: DescriptorPool to add resolved schemas to. - :return: DescriptorPool - """ - if visited is None: - visited = set() - if schema.references is not None: - for ref in schema.references: - if _is_builtin(ref.name) or ref.name in visited: - continue - visited.add(ref.name) - referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True, 'serialized') - _resolve_named_schema(referenced_schema.schema, schema_registry_client, pool, visited) - file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) - pool.Add(file_descriptor_proto) - - -def _init_pool(pool: DescriptorPool): - pool.AddSerializedFile(any_pb2.DESCRIPTOR.serialized_pb) - # source_context needed by api - pool.AddSerializedFile(source_context_pb2.DESCRIPTOR.serialized_pb) - # type needed by api - pool.AddSerializedFile(type_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(api_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(descriptor_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(duration_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(empty_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(field_mask_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(struct_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(timestamp_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(wrappers_pb2.DESCRIPTOR.serialized_pb) - - pool.AddSerializedFile(calendar_period_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(color_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(date_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(datetime_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(dayofweek_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(expr_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(fraction_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(latlng_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(money_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(month_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(postal_address_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(quaternion_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(timeofday_pb2.DESCRIPTOR.serialized_pb) - - pool.AddSerializedFile(meta_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(decimal_pb2.DESCRIPTOR.serialized_pb) - - -class ProtobufSerializer(BaseSerializer): - """ - Serializer for Protobuf Message derived classes. Serialization format is Protobuf, - with Confluent Schema Registry framing. 
- - Configuration properties: - - +-------------------------------------+----------+------------------------------------------------------+ - | Property Name | Type | Description | - +=====================================+==========+======================================================+ - | | | If True, automatically register the configured | - | ``auto.register.schemas`` | bool | schema with Confluent Schema Registry if it has | - | | | not previously been associated with the relevant | - | | | subject (determined via subject.name.strategy). | - | | | | - | | | Defaults to True. | - | | | | - | | | Raises SchemaRegistryError if the schema was not | - | | | registered against the subject, or could not be | - | | | successfully registered. | - +-------------------------------------+----------+------------------------------------------------------+ - | | | Whether to normalize schemas, which will | - | ``normalize.schemas`` | bool | transform schemas to have a consistent format, | - | | | including ordering properties and references. | - +-------------------------------------+----------+------------------------------------------------------+ - | | | Whether to use the given schema ID for | - | ``use.schema.id`` | int | serialization. | - | | | | - +-----------------------------------------+----------+--------------------------------------------------+ - | | | Whether to use the latest subject version for | - | ``use.latest.version`` | bool | serialization. | - | | | | - | | | WARNING: There is no check that the latest | - | | | schema is backwards compatible with the object | - | | | being serialized. | - | | | | - | | | Defaults to False. | - +-------------------------------------+----------+------------------------------------------------------+ - | | | Whether to use the latest subject version with | - | ``use.latest.with.metadata`` | dict | the given metadata. | - | | | | - | | | WARNING: There is no check that the latest | - | | | schema is backwards compatible with the object | - | | | being serialized. | - | | | | - | | | Defaults to None. | - +-------------------------------------+----------+------------------------------------------------------+ - | | | Whether or not to skip known types when resolving | - | ``skip.known.types`` | bool | schema dependencies. | - | | | | - | | | Defaults to True. | - +-------------------------------------+----------+------------------------------------------------------+ - | | | Callable(SerializationContext, str) -> str | - | | | | - | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | - | | | constructed. Standard naming strategies are | - | | | defined in the confluent_kafka.schema_registry | - | | | namespace. | - | | | | - | | | Defaults to topic_subject_name_strategy. | - +-------------------------------------+----------+------------------------------------------------------+ - | | | Callable(SerializationContext, str) -> str | - | | | | - | ``reference.subject.name.strategy`` | callable | Defines how Schema Registry subject names for schema | - | | | references are constructed. | - | | | | - | | | Defaults to reference_subject_name_strategy | - +-------------------------------------+----------+------------------------------------------------------+ - | ``use.deprecated.format`` | bool | Specifies whether the Protobuf serializer should | - | | | serialize message indexes without zig-zag encoding. 
| - | | | This option must be explicitly configured as older | - | | | and newer Protobuf producers are incompatible. | - | | | If the consumers of the topic being produced to are | - | | | using confluent-kafka-python <1.8 then this property | - | | | must be set to True until all old consumers have | - | | | have been upgraded. | - | | | | - | | | Warning: This configuration property will be removed | - | | | in a future version of the client. | - +-------------------------------------+----------+------------------------------------------------------+ - - Schemas are registered against subject names in Confluent Schema Registry that - define a scope in which the schemas can be evolved. By default, the subject name - is formed by concatenating the topic name with the message field (key or value) - separated by a hyphen. - - i.e. {topic name}-{message field} - - Alternative naming strategies may be configured with the property - ``subject.name.strategy``. - - Supported subject name strategies - - +--------------------------------------+------------------------------+ - | Subject Name Strategy | Output Format | - +======================================+==============================+ - | topic_subject_name_strategy(default) | {topic name}-{message field} | - +--------------------------------------+------------------------------+ - | topic_record_subject_name_strategy | {topic name}-{record name} | - +--------------------------------------+------------------------------+ - | record_subject_name_strategy | {record name} | - +--------------------------------------+------------------------------+ - - See `Subject name strategy `_ for additional details. - - Args: - msg_type (Message): Protobuf Message type. - - schema_registry_client (SchemaRegistryClient): Schema Registry - client instance. - - conf (dict): ProtobufSerializer configuration. - - See Also: - `Protobuf API reference `_ - """ # noqa: E501 - __slots__ = ['_skip_known_types', '_known_subjects', '_msg_class', '_index_array', - '_schema', '_schema_id', '_ref_reference_subject_func', - '_use_deprecated_format', '_parsed_schemas'] - - _default_conf = { - 'auto.register.schemas': True, - 'normalize.schemas': False, - 'use.schema.id': None, - 'use.latest.version': False, - 'use.latest.with.metadata': None, - 'skip.known.types': True, - 'subject.name.strategy': topic_subject_name_strategy, - 'reference.subject.name.strategy': reference_subject_name_strategy, - 'use.deprecated.format': False, - } - - def __init__( - self, - msg_type: Message, - schema_registry_client: SchemaRegistryClient, - conf: Optional[dict] = None, - rule_conf: Optional[dict] = None, - rule_registry: Optional[RuleRegistry] = None - ): - super().__init__() - - if conf is None or 'use.deprecated.format' not in conf: - raise RuntimeError( - "ProtobufSerializer: the 'use.deprecated.format' configuration " - "property must be explicitly set due to backward incompatibility " - "with older confluent-kafka-python Protobuf producers and consumers. 
" - "See the release notes for more details") - - conf_copy = self._default_conf.copy() - if conf is not None: - conf_copy.update(conf) - - self._auto_register = conf_copy.pop('auto.register.schemas') - if not isinstance(self._auto_register, bool): - raise ValueError("auto.register.schemas must be a boolean value") - - self._normalize_schemas = conf_copy.pop('normalize.schemas') - if not isinstance(self._normalize_schemas, bool): - raise ValueError("normalize.schemas must be a boolean value") - - self._use_schema_id = conf_copy.pop('use.schema.id') - if (self._use_schema_id is not None and - not isinstance(self._use_schema_id, int)): - raise ValueError("use.schema.id must be an int value") - - self._use_latest_version = conf_copy.pop('use.latest.version') - if not isinstance(self._use_latest_version, bool): - raise ValueError("use.latest.version must be a boolean value") - if self._use_latest_version and self._auto_register: - raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') - if (self._use_latest_with_metadata is not None and - not isinstance(self._use_latest_with_metadata, dict)): - raise ValueError("use.latest.with.metadata must be a dict value") - - self._skip_known_types = conf_copy.pop('skip.known.types') - if not isinstance(self._skip_known_types, bool): - raise ValueError("skip.known.types must be a boolean value") - - self._use_deprecated_format = conf_copy.pop('use.deprecated.format') - if not isinstance(self._use_deprecated_format, bool): - raise ValueError("use.deprecated.format must be a boolean value") - if self._use_deprecated_format: - warnings.warn("ProtobufSerializer: the 'use.deprecated.format' " - "configuration property, and the ability to use the " - "old incorrect Protobuf serializer heading format " - "introduced in confluent-kafka-python v1.4.0, " - "will be removed in an upcoming release in 2021 Q2. " - "Please migrate your Python Protobuf producers and " - "consumers to 'use.deprecated.format':False as " - "soon as possible") - - self._subject_name_func = conf_copy.pop('subject.name.strategy') - if not callable(self._subject_name_func): - raise ValueError("subject.name.strategy must be callable") - - self._ref_reference_subject_func = conf_copy.pop( - 'reference.subject.name.strategy') - if not callable(self._ref_reference_subject_func): - raise ValueError("subject.name.strategy must be callable") - - if len(conf_copy) > 0: - raise ValueError("Unrecognized properties: {}" - .format(", ".join(conf_copy.keys()))) - - self._registry = schema_registry_client - self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._schema_id = None - self._known_subjects = set() - self._msg_class = msg_type - self._parsed_schemas = ParsedSchemaCache() - - descriptor = msg_type.DESCRIPTOR - self._index_array = _create_index_array(descriptor) - self._schema = Schema(_schema_to_str(descriptor.file), - schema_type='PROTOBUF') - - for rule in self._rule_registry.get_executors(): - rule.configure(self._registry.config() if self._registry else {}, - rule_conf if rule_conf else {}) - - @staticmethod - def _write_varint(buf: io.BytesIO, val: int, zigzag: bool = True): - """ - Writes val to buf, either using zigzag or uvarint encoding. - - Args: - buf (BytesIO): buffer to write to. - val (int): integer to be encoded. 
- zigzag (bool): whether to encode in zigzag or uvarint encoding - """ - - if zigzag: - val = (val << 1) ^ (val >> 63) - - while (val & ~0x7f) != 0: - buf.write(_bytes((val & 0x7f) | 0x80)) - val >>= 7 - buf.write(_bytes(val)) - - @staticmethod - def _encode_varints(buf: io.BytesIO, ints: List[int], zigzag: bool = True): - """ - Encodes each int as a uvarint onto buf - - Args: - buf (BytesIO): buffer to write to. - ints ([int]): ints to be encoded. - zigzag (bool): whether to encode in zigzag or uvarint encoding - """ - - assert len(ints) > 0 - # The root element at the 0 position does not need a length prefix. - if ints == [0]: - buf.write(_bytes(0x00)) - return - - ProtobufSerializer._write_varint(buf, len(ints), zigzag=zigzag) - - for value in ints: - ProtobufSerializer._write_varint(buf, value, zigzag=zigzag) - - def _resolve_dependencies( - self, ctx: SerializationContext, - file_desc: FileDescriptor - ) -> List[SchemaReference]: - """ - Resolves and optionally registers schema references recursively. - - Args: - ctx (SerializationContext): Serialization context. - - file_desc (FileDescriptor): file descriptor to traverse. - """ - - schema_refs = [] - for dep in file_desc.dependencies: - if self._skip_known_types and _is_builtin(dep.name): - continue - dep_refs = self._resolve_dependencies(ctx, dep) - subject = self._ref_reference_subject_func(ctx, dep) - schema = Schema(_schema_to_str(dep), - references=dep_refs, - schema_type='PROTOBUF') - if self._auto_register: - self._registry.register_schema(subject, schema) - - reference = self._registry.lookup_schema(subject, schema) - # schema_refs are per file descriptor - schema_refs.append(SchemaReference(dep.name, - subject, - reference.version)) - return schema_refs - - def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: - """ - Serializes an instance of a class derived from Protobuf Message, and prepends - it with Confluent Schema Registry framing. - - Args: - message (Message): An instance of a class derived from Protobuf Message. - - ctx (SerializationContext): Metadata relevant to the serialization. - operation. - - Raises: - SerializerError if any error occurs during serialization. - - Returns: - None if messages is None, else a byte array containing the Protobuf - serialized message with Confluent Schema Registry framing. 
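A standalone sketch of the index-array encoding performed by _write_varint/_encode_varints above (zigzag unsigned varints, with the single-element [0] case collapsed to one zero byte); the byte strings can be checked by hand:

def encode_index(ints):
    # mirrors _encode_varints with zigzag=True, for illustration only
    if ints == [0]:
        return b'\x00'
    out = bytearray()
    for val in [len(ints)] + ints:
        val = (val << 1) ^ (val >> 63)        # zigzag; non-negative values simply double
        while val & ~0x7f:
            out.append((val & 0x7f) | 0x80)
            val >>= 7
        out.append(val)
    return bytes(out)

assert encode_index([0]) == b'\x00'              # first top-level message
assert encode_index([1]) == b'\x02\x02'          # second top-level message
assert encode_index([2, 0]) == b'\x04\x04\x00'   # first message nested inside the third top-level message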
- """ - - if message is None: - return None - - if not isinstance(message, self._msg_class): - raise ValueError("message must be of type {} not {}" - .format(self._msg_class, type(message))) - - subject = self._subject_name_func(ctx, - message.DESCRIPTOR.full_name) - latest_schema = self._get_reader_schema(subject, fmt='serialized') - if latest_schema is not None: - self._schema_id = latest_schema.schema_id - elif subject not in self._known_subjects: - references = self._resolve_dependencies( - ctx, message.DESCRIPTOR.file) - self._schema = Schema( - self._schema.schema_str, - self._schema.schema_type, - references - ) - - if self._auto_register: - self._schema_id = self._registry.register_schema(subject, - self._schema, - self._normalize_schemas) - else: - self._schema_id = self._registry.lookup_schema( - subject, self._schema, self._normalize_schemas).schema_id - - self._known_subjects.add(subject) - - if latest_schema is not None: - fd_proto, pool = self._get_parsed_schema(latest_schema.schema) - fd = pool.FindFileByName(fd_proto.name) - desc = fd.message_types_by_name[message.DESCRIPTOR.name] - field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 - transform(rule_ctx, desc, msg, field_transform)) - message = self._execute_rules(ctx, subject, RuleMode.WRITE, None, - latest_schema.schema, message, None, - field_transformer) - - with _ContextStringIO() as fo: - # Write the magic byte and schema ID in network byte order - # (big endian) - fo.write(struct.pack('>bI', _MAGIC_BYTE, self._schema_id)) - # write the index array that specifies the message descriptor - # of the serialized data. - self._encode_varints(fo, self._index_array, - zigzag=not self._use_deprecated_format) - # write the serialized data itself - fo.write(message.SerializeToString()) - return fo.getvalue() - - def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]: - result = self._parsed_schemas.get_parsed_schema(schema) - if result is not None: - return result - - pool = DescriptorPool() - _init_pool(pool) - _resolve_named_schema(schema, self._registry, pool) - fd_proto = _str_to_proto("default", schema.schema_str) - pool.Add(fd_proto) - self._parsed_schemas.set(schema, (fd_proto, pool)) - return fd_proto, pool - - -class ProtobufDeserializer(BaseDeserializer): - """ - Deserializer for Protobuf serialized data with Confluent Schema Registry framing. - - Args: - message_type (Message derived type): Protobuf Message type. - conf (dict): Configuration dictionary. - - ProtobufDeserializer configuration properties: - - +-------------------------------------+----------+------------------------------------------------------+ - | Property Name | Type | Description | - +-------------------------------------+----------+------------------------------------------------------+ - | | | Whether to use the latest subject version for | - | ``use.latest.version`` | bool | deserialization. | - | | | | - | | | Defaults to False. | - +-------------------------------------+----------+------------------------------------------------------+ - | | | Whether to use the latest subject version with | - | ``use.latest.with.metadata`` | dict | the given metadata. | - | | | | - | | | Defaults to None. | - +-------------------------------------+----------+------------------------------------------------------+ - | | | Callable(SerializationContext, str) -> str | - | | | | - | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | - | | | constructed. 
Standard naming strategies are | - | | | defined in the confluent_kafka. schema_registry | - | | | namespace . | - | | | | - | | | Defaults to topic_subject_name_strategy. | - +-------------------------------------+----------+------------------------------------------------------+ - | ``use.deprecated.format`` | bool | Specifies whether the Protobuf deserializer should | - | | | deserialize message indexes without zig-zag encoding.| - | | | This option must be explicitly configured as older | - | | | and newer Protobuf producers are incompatible. | - | | | If Protobuf messages in the topic to consume were | - | | | produced with confluent-kafka-python <1.8 then this | - | | | property must be set to True until all old messages | - | | | have been processed and producers have been upgraded.| - | | | Warning: This configuration property will be removed | - | | | in a future version of the client. | - +-------------------------------------+----------+------------------------------------------------------+ - - - See Also: - `Protobuf API reference `_ - """ - - __slots__ = ['_msg_class', '_use_deprecated_format', '_parsed_schemas'] - - _default_conf = { - 'use.latest.version': False, - 'use.latest.with.metadata': None, - 'subject.name.strategy': topic_subject_name_strategy, - 'use.deprecated.format': False, - } - - def __init__( - self, - message_type: Message, - conf: Optional[dict] = None, - schema_registry_client: Optional[SchemaRegistryClient] = None, - rule_conf: Optional[dict] = None, - rule_registry: Optional[RuleRegistry] = None - ): - super().__init__() - - self._registry = schema_registry_client - self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._parsed_schemas = ParsedSchemaCache() - self._use_schema_id = None - - # Require use.deprecated.format to be explicitly configured - # during a transitionary period since old/new format are - # incompatible. - if conf is None or 'use.deprecated.format' not in conf: - raise RuntimeError( - "ProtobufDeserializer: the 'use.deprecated.format' configuration " - "property must be explicitly set due to backward incompatibility " - "with older confluent-kafka-python Protobuf producers and consumers. " - "See the release notes for more details") - - conf_copy = self._default_conf.copy() - if conf is not None: - conf_copy.update(conf) - - self._use_latest_version = conf_copy.pop('use.latest.version') - if not isinstance(self._use_latest_version, bool): - raise ValueError("use.latest.version must be a boolean value") - - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') - if (self._use_latest_with_metadata is not None and - not isinstance(self._use_latest_with_metadata, dict)): - raise ValueError("use.latest.with.metadata must be a dict value") - - self._subject_name_func = conf_copy.pop('subject.name.strategy') - if not callable(self._subject_name_func): - raise ValueError("subject.name.strategy must be callable") - - self._use_deprecated_format = conf_copy.pop('use.deprecated.format') - if not isinstance(self._use_deprecated_format, bool): - raise ValueError("use.deprecated.format must be a boolean value") - if self._use_deprecated_format: - warnings.warn("ProtobufDeserializer: the 'use.deprecated.format' " - "configuration property, and the ability to use the " - "old incorrect Protobuf serializer heading format " - "introduced in confluent-kafka-python v1.4.0, " - "will be removed in an upcoming release in 2022 Q2. 
" - "Please migrate your Python Protobuf producers and " - "consumers to 'use.deprecated.format':False as " - "soon as possible") - - descriptor = message_type.DESCRIPTOR - self._msg_class = GetMessageClass(descriptor) - - for rule in self._rule_registry.get_executors(): - rule.configure(self._registry.config() if self._registry else {}, - rule_conf if rule_conf else {}) - - @staticmethod - def _decode_varint(buf: io.BytesIO, zigzag: bool = True) -> int: - """ - Decodes a single varint from a buffer. - - Args: - buf (BytesIO): buffer to read from - zigzag (bool): decode as zigzag or uvarint - - Returns: - int: decoded varint - - Raises: - EOFError: if buffer is empty - """ - - value = 0 - shift = 0 - try: - while True: - i = ProtobufDeserializer._read_byte(buf) - - value |= (i & 0x7f) << shift - shift += 7 - if not (i & 0x80): - break - - if zigzag: - value = (value >> 1) ^ -(value & 1) - - return value - - except EOFError: - raise EOFError("Unexpected EOF while reading index") - - @staticmethod - def _read_byte(buf: io.BytesIO) -> int: - """ - Read one byte from buf as an int. - - Args: - buf (BytesIO): The buffer to read from. - - .. _ord: - https://docs.python.org/2/library/functions.html#ord - """ - - i = buf.read(1) - if i == b'': - raise EOFError("Unexpected EOF encountered") - return ord(i) - - @staticmethod - def _read_index_array(buf: io.BytesIO, zigzag: bool = True) -> List[int]: - """ - Read an index array from buf that specifies the message - descriptor of interest in the file descriptor. - - Args: - buf (BytesIO): The buffer to read from. - - Returns: - list of int: The index array. - """ - - size = ProtobufDeserializer._decode_varint(buf, zigzag=zigzag) - if size < 0 or size > 100000: - raise DecodeError("Invalid Protobuf msgidx array length") - - if size == 0: - return [0] - - msg_index = [] - for _ in range(size): - msg_index.append(ProtobufDeserializer._decode_varint(buf, - zigzag=zigzag)) - - return msg_index - - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[Message]: - """ - Deserialize a serialized protobuf message with Confluent Schema Registry - framing. - - Args: - data (bytes): Serialized protobuf message with Confluent Schema - Registry framing. - - ctx (SerializationContext): Metadata relevant to the serialization - operation. - - Returns: - Message: Protobuf Message instance. - - Raises: - SerializerError: If there was an error reading the Confluent framing - data, or parsing the protobuf serialized message. - """ - - if data is None: - return None - - # SR wire protocol + msg_index length - if len(data) < 6: - raise SerializationError("Expecting data framing of length 6 bytes or " - "more but total data size is {} bytes. This " - "message was not produced with a Confluent " - "Schema Registry serializer".format(len(data))) - - subject = self._subject_name_func(ctx, None) - latest_schema = None - if subject is not None and self._registry is not None: - latest_schema = self._get_reader_schema(subject, fmt='serialized') - - with _ContextStringIO(data) as payload: - magic, schema_id = struct.unpack('>bI', payload.read(5)) - if magic != _MAGIC_BYTE: - raise SerializationError("Unknown magic byte. 
This message was " - "not produced with a Confluent " - "Schema Registry serializer") - - msg_index = self._read_index_array(payload, zigzag=not self._use_deprecated_format) - - if self._registry is not None: - writer_schema_raw = self._registry.get_schema(schema_id, fmt='serialized') - fd_proto, pool = self._get_parsed_schema(writer_schema_raw) - writer_schema = pool.FindFileByName(fd_proto.name) - writer_desc = self._get_message_desc(pool, writer_schema, msg_index) - if subject is None: - subject = self._subject_name_func(ctx, writer_desc.full_name) - if subject is not None: - latest_schema = self._get_reader_schema(subject, fmt='serialized') - else: - writer_schema_raw = None - writer_schema = None - - if latest_schema is not None: - migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) - reader_schema_raw = latest_schema.schema - fd_proto, pool = self._get_parsed_schema(latest_schema.schema) - reader_schema = pool.FindFileByName(fd_proto.name) - else: - migrations = None - reader_schema_raw = writer_schema_raw - reader_schema = writer_schema - - if reader_schema is not None: - # Initialize reader desc to first message in file - reader_desc = self._get_message_desc(pool, reader_schema, [0]) - # Attempt to find a reader desc with the same name as the writer - reader_desc = reader_schema.message_types_by_name.get(writer_desc.name, reader_desc) - - if migrations: - msg = GetMessageClass(writer_desc)() - try: - msg.ParseFromString(payload.read()) - except DecodeError as e: - raise SerializationError(str(e)) - - obj_dict = json_format.MessageToDict(msg, True) - obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) - msg = GetMessageClass(reader_desc)() - msg = json_format.ParseDict(obj_dict, msg) - else: - # Protobuf Messages are self-describing; no need to query schema - msg = self._msg_class() - try: - msg.ParseFromString(payload.read()) - except DecodeError as e: - raise SerializationError(str(e)) - - field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 - transform(rule_ctx, reader_desc, message, field_transform)) - msg = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, msg, None, - field_transformer) - - return msg - - def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]: - result = self._parsed_schemas.get_parsed_schema(schema) - if result is not None: - return result - - pool = DescriptorPool() - _init_pool(pool) - _resolve_named_schema(schema, self._registry, pool) - fd_proto = _str_to_proto("default", schema.schema_str) - pool.Add(fd_proto) - self._parsed_schemas.set(schema, (fd_proto, pool)) - return fd_proto, pool - - def _get_message_desc( - self, pool: DescriptorPool, fd: FileDescriptor, - msg_index: List[int] - ) -> Descriptor: - file_desc_proto = descriptor_pb2.FileDescriptorProto() - fd.CopyToProto(file_desc_proto) - (full_name, desc_proto) = self._get_message_desc_proto("", file_desc_proto, msg_index) - package = file_desc_proto.package - qualified_name = package + "." + full_name if package else full_name - return pool.FindMessageTypeByName(qualified_name) - - def _get_message_desc_proto( - self, - path: str, - desc: Union[descriptor_pb2.FileDescriptorProto, descriptor_pb2.DescriptorProto], - msg_index: List[int] - ) -> Tuple[str, descriptor_pb2.DescriptorProto]: - index = msg_index[0] - if isinstance(desc, descriptor_pb2.FileDescriptorProto): - msg = desc.message_type[index] - path = path + "." 
+ msg.name if path else msg.name - if len(msg_index) == 1: - return path, msg - return self._get_message_desc_proto(path, msg, msg_index[1:]) - else: - msg = desc.nested_type[index] - path = path + "." + msg.name if path else msg.name - if len(msg_index) == 1: - return path, msg - return self._get_message_desc_proto(path, msg, msg_index[1:]) - - -def transform( - ctx: RuleContext, descriptor: Descriptor, message: Any, - field_transform: FieldTransform -) -> Any: - if message is None or descriptor is None: - return message - if isinstance(message, list): - return [transform(ctx, descriptor, item, field_transform) - for item in message] - if isinstance(message, dict): - return {key: transform(ctx, descriptor, value, field_transform) - for key, value in message.items()} - if isinstance(message, Message): - for fd in descriptor.fields: - _transform_field(ctx, fd, descriptor, message, field_transform) - return message - field_ctx = ctx.current_field() - if field_ctx is not None: - rule_tags = ctx.rule.tags - if not rule_tags or not _disjoint(set(rule_tags), field_ctx.tags): - return field_transform(ctx, field_ctx, message) - return message - - -def _transform_field( - ctx: RuleContext, fd: FieldDescriptor, desc: Descriptor, - message: Message, field_transform: FieldTransform -): - try: - ctx.enter_field( - message, - fd.full_name, - fd.name, - get_type(fd), - get_inline_tags(fd) - ) - if fd.containing_oneof is not None and not message.HasField(fd.name): - return - value = getattr(message, fd.name) - if is_map_field(fd): - value = {key: value[key] for key in value} - elif fd.label == FieldDescriptor.LABEL_REPEATED: - value = [item for item in value] - new_value = transform(ctx, desc, value, field_transform) - if ctx.rule.kind == RuleKind.CONDITION: - if new_value is False: - raise RuleConditionError(ctx.rule) - else: - _set_field(fd, message, new_value) - finally: - ctx.exit_field() - - -def _set_field(fd: FieldDescriptor, message: Message, value: Any): - if isinstance(value, list): - message.ClearField(fd.name) - old_value = getattr(message, fd.name) - old_value.extend(value) - elif isinstance(value, dict): - message.ClearField(fd.name) - old_value = getattr(message, fd.name) - old_value.update(value) - else: - setattr(message, fd.name, value) - - -def get_type(fd: FieldDescriptor) -> FieldType: - if is_map_field(fd): - return FieldType.MAP - if fd.type == FieldDescriptor.TYPE_MESSAGE: - return FieldType.RECORD - if fd.type == FieldDescriptor.TYPE_ENUM: - return FieldType.ENUM - if fd.type == FieldDescriptor.TYPE_STRING: - return FieldType.STRING - if fd.type == FieldDescriptor.TYPE_BYTES: - return FieldType.BYTES - if fd.type in (FieldDescriptor.TYPE_INT32, FieldDescriptor.TYPE_SINT32, - FieldDescriptor.TYPE_UINT32, FieldDescriptor.TYPE_FIXED32, - FieldDescriptor.TYPE_SFIXED32): - return FieldType.INT - if fd.type in (FieldDescriptor.TYPE_INT64, FieldDescriptor.TYPE_SINT64, - FieldDescriptor.TYPE_UINT64, FieldDescriptor.TYPE_FIXED64, - FieldDescriptor.TYPE_SFIXED64): - return FieldType.LONG - if fd.type == FieldDescriptor.TYPE_FLOAT: - return FieldType.FLOAT - if fd.type == FieldDescriptor.TYPE_DOUBLE: - return FieldType.DOUBLE - if fd.type == FieldDescriptor.TYPE_BOOL: - return FieldType.BOOLEAN - return FieldType.NULL - - -def is_map_field(fd: FieldDescriptor): - return (fd.type == FieldDescriptor.TYPE_MESSAGE - and hasattr(fd.message_type, 'options') - and fd.message_type.options.map_entry) - - -def get_inline_tags(fd: FieldDescriptor) -> Set[str]: - meta = 
fd.GetOptions().Extensions[meta_pb2.field_meta] - if meta is None: - return set() - else: - return set(meta.tags) - - -def _disjoint(tags1: Set[str], tags2: Set[str]) -> bool: - for tag in tags1: - if tag in tags2: - return False - return True - - -def _is_builtin(name: str) -> bool: - return name.startswith('confluent/') or \ - name.startswith('google/protobuf/') or \ - name.startswith('google/type/') - - -def decimalToProtobuf(value: Decimal, scale: int) -> decimal_pb2.Decimal: - """ - Converts a Decimal to a Protobuf value. - - Args: - value (Decimal): The Decimal value to convert. - - Returns: - The Protobuf value. - """ - sign, digits, exp = value.as_tuple() - - delta = exp + scale - - if delta < 0: - raise ValueError( - "Scale provided does not match the decimal") - - unscaled_datum = 0 - for digit in digits: - unscaled_datum = (unscaled_datum * 10) + digit - - unscaled_datum = 10**delta * unscaled_datum - - bytes_req = (unscaled_datum.bit_length() + 8) // 8 - - if sign: - unscaled_datum = -unscaled_datum - - bytes = unscaled_datum.to_bytes(bytes_req, byteorder="big", signed=True) - - result = decimal_pb2.Decimal() - result.value = bytes - result.precision = 0 - result.scale = scale - return result - - -decimal_context = Context() - - -def protobufToDecimal(value: decimal_pb2.Decimal) -> Decimal: - """ - Converts a Protobuf value to Decimal. - - Args: - value (decimal_pb2.Decimal): The Protobuf value to convert. - - Returns: - The Decimal value. - """ - unscaled_datum = int.from_bytes(value.value, byteorder="big", signed=True) - - if value.precision > 0: - decimal_context.prec = value.precision - else: - decimal_context.prec = MAX_PREC - return decimal_context.create_decimal(unscaled_datum).scaleb( - -value.scale, decimal_context - ) +from .common.protobuf import * +from ._sync.protobuf import * diff --git a/src/confluent_kafka/schema_registry/schema_registry_client.py b/src/confluent_kafka/schema_registry/schema_registry_client.py index 4cadf8bfd..1763206ff 100644 --- a/src/confluent_kafka/schema_registry/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/schema_registry_client.py @@ -14,1968 +14,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -import abc -import json -import logging -import random -import time -import urllib -from urllib.parse import unquote, urlparse - -import httpx -from attrs import define as _attrs_define -from attrs import field as _attrs_field -from collections import defaultdict -from enum import Enum -from threading import Lock -from typing import List, Dict, Type, TypeVar, \ - cast, Optional, Union, Any, Tuple, Callable - -from cachetools import TTLCache, LRUCache -from httpx import Response - -from authlib.integrations.httpx_client import OAuth2Client - -from .error import SchemaRegistryError, OAuthTokenError - -# TODO: consider adding `six` dependency or employing a compat file -# Python 2.7 is officially EOL so compatibility issue will be come more the norm. -# We need a better way to handle these issues. -# Six is one possibility but the compat file pattern used by requests -# is also quite nice. 
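Tying back to the decimalToProtobuf/protobufToDecimal helpers removed from protobuf.py above, a small worked example of the unscaled-integer representation they rely on, using Decimal('12.34') with scale 2:

from decimal import Decimal

value = Decimal('12.34')
sign, digits, exp = value.as_tuple()         # (0, (1, 2, 3, 4), -2)
unscaled = int(''.join(map(str, digits)))    # 1234
assert unscaled.to_bytes(2, 'big', signed=True) == b'\x04\xd2'   # big-endian two's complement, as stored in the Decimal message's value field
assert Decimal(unscaled).scaleb(-2) == value                     # decoding scales back by 10**-scale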
-# -# six: https://pypi.org/project/six/ -# compat file : https://github.com/psf/requests/blob/master/requests/compat.py -try: - string_type = basestring # noqa - - def _urlencode(value: str) -> str: - return urllib.quote(value, safe='') -except NameError: - string_type = str - - def _urlencode(value: str) -> str: - return urllib.parse.quote(value, safe='') - -log = logging.getLogger(__name__) -VALID_AUTH_PROVIDERS = ['URL', 'USER_INFO'] - - -class _BearerFieldProvider(metaclass=abc.ABCMeta): - @abc.abstractmethod - def get_bearer_fields(self) -> dict: - raise NotImplementedError - - -class _StaticFieldProvider(_BearerFieldProvider): - def __init__(self, token: str, logical_cluster: str, identity_pool: str): - self.token = token - self.logical_cluster = logical_cluster - self.identity_pool = identity_pool - - def get_bearer_fields(self) -> dict: - return {'bearer.auth.token': self.token, 'bearer.auth.logical.cluster': self.logical_cluster, - 'bearer.auth.identity.pool.id': self.identity_pool} - - -class _CustomOAuthClient(_BearerFieldProvider): - def __init__(self, custom_function: Callable[[Dict], Dict], custom_config: dict): - self.custom_function = custom_function - self.custom_config = custom_config - - def get_bearer_fields(self) -> dict: - return self.custom_function(self.custom_config) - - -class _OAuthClient(_BearerFieldProvider): - def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str, logical_cluster: str, - identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int): - self.token = None - self.logical_cluster = logical_cluster - self.identity_pool = identity_pool - self.client = OAuth2Client(client_id=client_id, client_secret=client_secret, scope=scope) - self.token_endpoint = token_endpoint - self.max_retries = max_retries - self.retries_wait_ms = retries_wait_ms - self.retries_max_wait_ms = retries_max_wait_ms - self.token_expiry_threshold = 0.8 - - def get_bearer_fields(self) -> dict: - return {'bearer.auth.token': self.get_access_token(), 'bearer.auth.logical.cluster': self.logical_cluster, - 'bearer.auth.identity.pool.id': self.identity_pool} - - def token_expired(self) -> bool: - expiry_window = self.token['expires_in'] * self.token_expiry_threshold - - return self.token['expires_at'] < time.time() + expiry_window - - def get_access_token(self) -> str: - if not self.token or self.token_expired(): - self.generate_access_token() - - return self.token['access_token'] - - def generate_access_token(self) -> None: - for i in range(self.max_retries + 1): - try: - self.token = self.client.fetch_token(url=self.token_endpoint, grant_type='client_credentials') - return - except Exception as e: - if i >= self.max_retries: - raise OAuthTokenError(f"Failed to retrieve token after {self.max_retries} " - f"attempts due to error: {str(e)}") - time.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000) - - -class _BaseRestClient(object): - - def __init__(self, conf: dict): - # copy dict to avoid mutating the original - conf_copy = conf.copy() - - base_url = conf_copy.pop('url', None) - if base_url is None: - raise ValueError("Missing required configuration property url") - if not isinstance(base_url, string_type): - raise TypeError("url must be a str, not " + str(type(base_url))) - base_urls = [] - for url in base_url.split(','): - url = url.strip().rstrip('/') - if not url.startswith('http') and not url.startswith('mock'): - raise ValueError("Invalid url {}".format(url)) - base_urls.append(url) - if not 
base_urls: - raise ValueError("Missing required configuration property url") - self.base_urls = base_urls - - self.verify = True - ca = conf_copy.pop('ssl.ca.location', None) - if ca is not None: - self.verify = ca - - key: Optional[str] = conf_copy.pop('ssl.key.location', None) - client_cert: Optional[str] = conf_copy.pop('ssl.certificate.location', None) - self.cert: Union[str, Tuple[str, str], None] = None - - if client_cert is not None and key is not None: - self.cert = (client_cert, key) - - if client_cert is not None and key is None: - self.cert = client_cert - - if key is not None and client_cert is None: - raise ValueError("ssl.certificate.location required when" - " configuring ssl.key.location") - - parsed = urlparse(self.base_urls[0]) - try: - userinfo = (unquote(parsed.username), unquote(parsed.password)) - except (AttributeError, TypeError): - userinfo = ("", "") - if 'basic.auth.user.info' in conf_copy: - if userinfo != ('', ''): - raise ValueError("basic.auth.user.info configured with" - " userinfo credentials in the URL." - " Remove userinfo credentials from the url or" - " remove basic.auth.user.info from the" - " configuration") - - userinfo = tuple(conf_copy.pop('basic.auth.user.info', '').split(':', 1)) - - if len(userinfo) != 2: - raise ValueError("basic.auth.user.info must be in the form" - " of {username}:{password}") - - self.auth = userinfo if userinfo != ('', '') else None - - # The following adds support for proxy config - # If specified: it uses the specified proxy details when making requests - self.proxy = None - proxy = conf_copy.pop('proxy', None) - if proxy is not None: - self.proxy = proxy - - self.timeout = None - timeout = conf_copy.pop('timeout', None) - if timeout is not None: - self.timeout = timeout - - self.cache_capacity = 1000 - cache_capacity = conf_copy.pop('cache.capacity', None) - if cache_capacity is not None: - if not isinstance(cache_capacity, (int, float)): - raise TypeError("cache.capacity must be a number, not " + str(type(cache_capacity))) - self.cache_capacity = cache_capacity - - self.cache_latest_ttl_sec = None - cache_latest_ttl_sec = conf_copy.pop('cache.latest.ttl.sec', None) - if cache_latest_ttl_sec is not None: - if not isinstance(cache_latest_ttl_sec, (int, float)): - raise TypeError("cache.latest.ttl.sec must be a number, not " + str(type(cache_latest_ttl_sec))) - self.cache_latest_ttl_sec = cache_latest_ttl_sec - - self.max_retries = 3 - max_retries = conf_copy.pop('max.retries', None) - if max_retries is not None: - if not isinstance(max_retries, (int, float)): - raise TypeError("max.retries must be a number, not " + str(type(max_retries))) - self.max_retries = max_retries - - self.retries_wait_ms = 1000 - retries_wait_ms = conf_copy.pop('retries.wait.ms', None) - if retries_wait_ms is not None: - if not isinstance(retries_wait_ms, (int, float)): - raise TypeError("retries.wait.ms must be a number, not " - + str(type(retries_wait_ms))) - self.retries_wait_ms = retries_wait_ms - - self.retries_max_wait_ms = 20000 - retries_max_wait_ms = conf_copy.pop('retries.max.wait.ms', None) - if retries_max_wait_ms is not None: - if not isinstance(retries_max_wait_ms, (int, float)): - raise TypeError("retries.max.wait.ms must be a number, not " - + str(type(retries_max_wait_ms))) - self.retries_max_wait_ms = retries_max_wait_ms - - self.bearer_field_provider = None - logical_cluster = None - identity_pool = None - self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None) - if 
self.bearer_auth_credentials_source is not None: - self.auth = None - - if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}: - headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id'] - missing_headers = [header for header in headers if header not in conf_copy] - if missing_headers: - raise ValueError("Missing required bearer configuration properties: {}" - .format(", ".join(missing_headers))) - - logical_cluster = conf_copy.pop('bearer.auth.logical.cluster') - if not isinstance(logical_cluster, str): - raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster))) - - identity_pool = conf_copy.pop('bearer.auth.identity.pool.id') - if not isinstance(identity_pool, str): - raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) - - if self.bearer_auth_credentials_source == 'OAUTHBEARER': - properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope', - 'bearer.auth.issuer.endpoint.url'] - missing_properties = [prop for prop in properties_list if prop not in conf_copy] - if missing_properties: - raise ValueError("Missing required OAuth configuration properties: {}". - format(", ".join(missing_properties))) - - self.client_id = conf_copy.pop('bearer.auth.client.id') - if not isinstance(self.client_id, string_type): - raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id))) - - self.client_secret = conf_copy.pop('bearer.auth.client.secret') - if not isinstance(self.client_secret, string_type): - raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret))) - - self.scope = conf_copy.pop('bearer.auth.scope') - if not isinstance(self.scope, string_type): - raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope))) - - self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url') - if not isinstance(self.token_endpoint, string_type): - raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " - + str(type(self.token_endpoint))) - - self.bearer_field_provider = _OAuthClient(self.client_id, self.client_secret, self.scope, - self.token_endpoint, logical_cluster, identity_pool, - self.max_retries, self.retries_wait_ms, - self.retries_max_wait_ms) - elif self.bearer_auth_credentials_source == 'STATIC_TOKEN': - if 'bearer.auth.token' not in conf_copy: - raise ValueError("Missing bearer.auth.token") - static_token = conf_copy.pop('bearer.auth.token') - self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) - if not isinstance(static_token, string_type): - raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) - elif self.bearer_auth_credentials_source == 'CUSTOM': - custom_bearer_properties = ['bearer.auth.custom.provider.function', - 'bearer.auth.custom.provider.config'] - missing_custom_properties = [prop for prop in custom_bearer_properties if prop not in conf_copy] - if missing_custom_properties: - raise ValueError("Missing required custom OAuth configuration properties: {}". 
- format(", ".join(missing_custom_properties))) - - custom_function = conf_copy.pop('bearer.auth.custom.provider.function') - if not callable(custom_function): - raise TypeError("bearer.auth.custom.provider.function must be a callable, not " - + str(type(custom_function))) - - custom_config = conf_copy.pop('bearer.auth.custom.provider.config') - if not isinstance(custom_config, dict): - raise TypeError("bearer.auth.custom.provider.config must be a dict, not " - + str(type(custom_config))) - - self.bearer_field_provider = _CustomOAuthClient(custom_function, custom_config) - else: - raise ValueError('Unrecognized bearer.auth.credentials.source') - - # Any leftover keys are unknown to _RestClient - if len(conf_copy) > 0: - raise ValueError("Unrecognized properties: {}" - .format(", ".join(conf_copy.keys()))) - - def get(self, url: str, query: Optional[dict] = None) -> Any: - raise NotImplementedError() - - def post(self, url: str, body: Optional[dict], **kwargs) -> Any: - raise NotImplementedError() - - def delete(self, url: str) -> Any: - raise NotImplementedError() - - def put(self, url: str, body: Optional[dict] = None) -> Any: - raise NotImplementedError() - - -class _RestClient(_BaseRestClient): - """ - HTTP client for Confluent Schema Registry. - - See SchemaRegistryClient for configuration details. - - Args: - conf (dict): Dictionary containing _RestClient configuration - """ - - def __init__(self, conf: dict): - super().__init__(conf) - - self.session = httpx.Client( - verify=self.verify, - cert=self.cert, - auth=self.auth, - proxy=self.proxy, - timeout=self.timeout - ) - - def handle_bearer_auth(self, headers: dict) -> None: - bearer_fields = self.bearer_field_provider.get_bearer_fields() - required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster'] - - missing_fields = [] - for field in required_fields: - if field not in bearer_fields: - missing_fields.append(field) - - if missing_fields: - raise ValueError("Missing required bearer auth fields, needs to be set in config or custom function: {}" - .format(", ".join(missing_fields))) - - headers["Authorization"] = "Bearer {}".format(bearer_fields['bearer.auth.token']) - headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] - headers['target-sr-cluster'] = bearer_fields['bearer.auth.logical.cluster'] - - def get(self, url: str, query: Optional[dict] = None) -> Any: - return self.send_request(url, method='GET', query=query) - - def post(self, url: str, body: Optional[dict], **kwargs) -> Any: - return self.send_request(url, method='POST', body=body) - - def delete(self, url: str) -> Any: - return self.send_request(url, method='DELETE') - - def put(self, url: str, body: Optional[dict] = None) -> Any: - return self.send_request(url, method='PUT', body=body) - - def send_request( - self, url: str, method: str, body: Optional[dict] = None, - query: Optional[dict] = None - ) -> Any: - """ - Sends HTTP request to the SchemaRegistry, trying each base URL in turn. - - All unsuccessful attempts will raise a SchemaRegistryError with the - response contents. In most cases this will be accompanied by a - Schema Registry supplied error code. - - In the event the response is malformed an error_code of -1 will be used. - - Args: - url (str): Request path - - method (str): HTTP method - - body (str): Request content - - query (dict): Query params to attach to the URL - - Returns: - dict: Schema Registry response content. 
- """ - - headers = {'Accept': "application/vnd.schemaregistry.v1+json," - " application/vnd.schemaregistry+json," - " application/json"} - - if body is not None: - body = json.dumps(body) - headers = {'Content-Length': str(len(body)), - 'Content-Type': "application/vnd.schemaregistry.v1+json"} - - if self.bearer_auth_credentials_source: - self.handle_bearer_auth(headers) - - response = None - for i, base_url in enumerate(self.base_urls): - try: - response = self.send_http_request( - base_url, url, method, headers, body, query) - - if is_success(response.status_code): - return response.json() - - if not is_retriable(response.status_code) or i == len(self.base_urls) - 1: - break - except Exception as e: - if i == len(self.base_urls) - 1: - # Raise the exception since we have no more urls to try - raise e - - try: - raise SchemaRegistryError(response.status_code, - response.json().get('error_code'), - response.json().get('message')) - # Schema Registry may return malformed output when it hits unexpected errors - except (ValueError, KeyError, AttributeError): - raise SchemaRegistryError(response.status_code, - -1, - "Unknown Schema Registry Error: " - + str(response.content)) - - def send_http_request( - self, base_url: str, url: str, method: str, headers: Optional[dict], - body: Optional[str] = None, query: Optional[dict] = None - ) -> Response: - """ - Sends HTTP request to the SchemaRegistry. - - All unsuccessful attempts will raise a SchemaRegistryError with the - response contents. In most cases this will be accompanied by a - Schema Registry supplied error code. - - In the event the response is malformed an error_code of -1 will be used. - - Args: - base_url (str): Schema Registry base URL - - url (str): Request path - - method (str): HTTP method - - headers (dict): Headers - - body (str): Request content - - query (dict): Query params to attach to the URL - - Returns: - Response: Schema Registry response content. - """ - response = None - for i in range(self.max_retries + 1): - response = self.session.request( - method, url="/".join([base_url, url]), - headers=headers, content=body, params=query) - - if is_success(response.status_code): - return response - - if not is_retriable(response.status_code) or i >= self.max_retries: - return response - - time.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000) - return response - - -def is_success(status_code: int) -> bool: - return 200 <= status_code <= 299 - - -def is_retriable(status_code: int) -> bool: - return status_code in (408, 429, 500, 502, 503, 504) - - -def full_jitter(base_delay_ms: int, max_delay_ms: int, retries_attempted: int) -> float: - no_jitter_delay = base_delay_ms * (2.0 ** retries_attempted) - return random.random() * min(no_jitter_delay, max_delay_ms) - - -class _SchemaCache(object): - """ - Thread-safe cache for use with the Schema Registry Client. - - This cache may be used to retrieve schema ids, schemas or to check - known subject membership. - """ - - def __init__(self): - self.lock = Lock() - self.schema_id_index = defaultdict(dict) - self.schema_index = defaultdict(dict) - self.rs_id_index = defaultdict(dict) - self.rs_version_index = defaultdict(dict) - self.rs_schema_index = defaultdict(dict) - - def set_schema(self, subject: str, schema_id: int, schema: 'Schema'): - """ - Add a Schema identified by schema_id to the cache. 
- - Args: - subject (str): The subject this schema is associated with - - schema_id (int): Schema's registration id - - schema (Schema): Schema instance - """ - - with self.lock: - self.schema_id_index[subject][schema_id] = schema - self.schema_index[subject][schema] = schema_id - - def set_registered_schema(self, schema: 'Schema', registered_schema: 'RegisteredSchema'): - """ - Add a RegisteredSchema to the cache. - - Args: - registered_schema (RegisteredSchema): RegisteredSchema instance - """ - - subject = registered_schema.subject - schema_id = registered_schema.schema_id - version = registered_schema.version - with self.lock: - self.schema_id_index[subject][schema_id] = schema - self.schema_index[subject][schema] = schema_id - self.rs_id_index[subject][schema_id] = registered_schema - self.rs_version_index[subject][version] = registered_schema - self.rs_schema_index[subject][schema] = registered_schema - - def get_schema_by_id(self, subject: str, schema_id: int) -> Optional['Schema']: - """ - Get the schema instance associated with schema id from the cache. - - Args: - subject (str): The subject this schema is associated with - - schema_id (int): Id used to identify a schema - - Returns: - Schema: The schema if known; else None - """ - - with self.lock: - return self.schema_id_index.get(subject, {}).get(schema_id, None) - - def get_id_by_schema(self, subject: str, schema: 'Schema') -> Optional[int]: - """ - Get the schema id associated with schema instance from the cache. - - Args: - subject (str): The subject this schema is associated with - - schema (Schema): The schema - - Returns: - int: The schema id if known; else None - """ - - with self.lock: - return self.schema_index.get(subject, {}).get(schema, None) - - def get_registered_by_subject_schema(self, subject: str, schema: 'Schema') -> Optional['RegisteredSchema']: - """ - Get the schema associated with this schema registered under subject. - - Args: - subject (str): The subject this schema is associated with - - schema (Schema): The schema associated with this schema - - Returns: - RegisteredSchema: The registered schema if known; else None - """ - - with self.lock: - return self.rs_schema_index.get(subject, {}).get(schema, None) - - def get_registered_by_subject_id(self, subject: str, schema_id: int) -> Optional['RegisteredSchema']: - """ - Get the schema associated with this id registered under subject. - - Args: - subject (str): The subject this schema is associated with - - schema_id (int): The schema id associated with this schema - - Returns: - RegisteredSchema: The registered schema if known; else None - """ - - with self.lock: - return self.rs_id_index.get(subject, {}).get(schema_id, None) - - def get_registered_by_subject_version(self, subject: str, version: int) -> Optional['RegisteredSchema']: - """ - Get the schema associated with this version registered under subject. - - Args: - subject (str): The subject this schema is associated with - - version (int): The version associated with this schema - - Returns: - RegisteredSchema: The registered schema if known; else None - """ - - with self.lock: - return self.rs_version_index.get(subject, {}).get(version, None) - - def remove_by_subject(self, subject: str): - """ - Remove schemas with the given subject. 
- - Args: - subject (str): The subject - """ - - with self.lock: - if subject in self.schema_id_index: - del self.schema_id_index[subject] - if subject in self.schema_index: - del self.schema_index[subject] - if subject in self.rs_id_index: - del self.rs_id_index[subject] - if subject in self.rs_version_index: - del self.rs_version_index[subject] - if subject in self.rs_schema_index: - del self.rs_schema_index[subject] - - def remove_by_subject_version(self, subject: str, version: int): - """ - Remove schemas with the given subject. - - Args: - subject (str): The subject - - version (int) The version - """ - - with self.lock: - if subject in self.rs_id_index: - for schema_id, registered_schema in self.rs_id_index[subject].items(): - if registered_schema.version == version: - del self.rs_schema_index[subject][schema_id] - if subject in self.rs_schema_index: - for schema, registered_schema in self.rs_schema_index[subject].items(): - if registered_schema.version == version: - del self.rs_schema_index[subject][schema] - rs = None - if subject in self.rs_version_index: - if version in self.rs_version_index[subject]: - rs = self.rs_version_index[subject][version] - del self.rs_version_index[subject][version] - if rs is not None: - if subject in self.schema_id_index: - if rs.schema_id in self.schema_id_index[subject]: - del self.schema_id_index[subject][rs.schema_id] - if rs.schema in self.schema_index[subject]: - del self.schema_index[subject][rs.schema] - - def clear(self): - """ - Clear the cache. - """ - - with self.lock: - self.schema_id_index.clear() - self.schema_index.clear() - self.rs_id_index.clear() - self.rs_version_index.clear() - self.rs_schema_index.clear() - - -class SchemaRegistryClient(object): - """ - A Confluent Schema Registry client. - - Configuration properties (* indicates a required field): - - +------------------------------+------+-------------------------------------------------+ - | Property name | type | Description | - +==============================+======+=================================================+ - | ``url`` * | str | Comma-separated list of Schema Registry URLs. | - +------------------------------+------+-------------------------------------------------+ - | | | Path to CA certificate file used | - | ``ssl.ca.location`` | str | to verify the Schema Registry's | - | | | private key. | - +------------------------------+------+-------------------------------------------------+ - | | | Path to client's private key | - | | | (PEM) used for authentication. | - | ``ssl.key.location`` | str | | - | | | ``ssl.certificate.location`` must also be set. | - +------------------------------+------+-------------------------------------------------+ - | | | Path to client's public key (PEM) used for | - | | | authentication. | - | ``ssl.certificate.location`` | str | | - | | | May be set without ssl.key.location if the | - | | | private key is stored within the PEM as well. | - +------------------------------+------+-------------------------------------------------+ - | | | Client HTTP credentials in the form of | - | | | ``username:password``. | - | ``basic.auth.user.info`` | str | | - | | | By default userinfo is extracted from | - | | | the URL if present. | - +------------------------------+------+-------------------------------------------------+ - | | | | - | ``proxy`` | str | Proxy such as http://localhost:8030. | - | | | | - +------------------------------+------+-------------------------------------------------+ - | | | | - | ``timeout`` | int | Request timeout. 
| - | | | | - +------------------------------+------+-------------------------------------------------+ - | | | | - | ``cache.capacity`` | int | Cache capacity. Defaults to 1000. | - | | | | - +------------------------------+------+-------------------------------------------------+ - | | | | - | ``cache.latest.ttl.sec`` | int | TTL in seconds for caching the latest schema. | - | | | | - +------------------------------+------+-------------------------------------------------+ - | | | | - | ``max.retries`` | int | Maximum retries for a request. Defaults to 2. | - | | | | - +------------------------------+------+-------------------------------------------------+ - | | | Maximum time to wait for the first retry. | - | | | When jitter is applied, the actual wait may | - | ``retries.wait.ms`` | int | be less. | - | | | | - | | | Defaults to 1000. | - +------------------------------+------+-------------------------------------------------+ - - Args: - conf (dict): Schema Registry client configuration. - - See Also: - `Confluent Schema Registry documentation `_ - """ # noqa: E501 - - def __init__(self, conf: dict): - self._conf = conf - self._rest_client = _RestClient(conf) - self._cache = _SchemaCache() - cache_capacity = self._rest_client.cache_capacity - cache_ttl = self._rest_client.cache_latest_ttl_sec - if cache_ttl is not None: - self._latest_version_cache = TTLCache(cache_capacity, cache_ttl) - self._latest_with_metadata_cache = TTLCache(cache_capacity, cache_ttl) - else: - self._latest_version_cache = LRUCache(cache_capacity) - self._latest_with_metadata_cache = LRUCache(cache_capacity) - - def __enter__(self): - return self - - def __exit__(self, *args): - if self._rest_client is not None: - self._rest_client.session.close() - - def config(self): - return self._conf - - def register_schema( - self, subject_name: str, schema: 'Schema', - normalize_schemas: bool = False - ) -> int: - """ - Registers a schema under ``subject_name``. - - Args: - subject_name (str): subject to register a schema under - schema (Schema): Schema instance to register - normalize_schemas (bool): Normalize schema before registering - - Returns: - int: Schema id - - Raises: - SchemaRegistryError: if Schema violates this subject's - Compatibility policy or is otherwise invalid. - - See Also: - `POST Subject API Reference `_ - """ # noqa: E501 - - registered_schema = self.register_schema_full_response(subject_name, schema, normalize_schemas) - return registered_schema.schema_id - - def register_schema_full_response( - self, subject_name: str, schema: 'Schema', - normalize_schemas: bool = False - ) -> 'RegisteredSchema': - """ - Registers a schema under ``subject_name``. - - Args: - subject_name (str): subject to register a schema under - schema (Schema): Schema instance to register - normalize_schemas (bool): Normalize schema before registering - - Returns: - int: Schema id - - Raises: - SchemaRegistryError: if Schema violates this subject's - Compatibility policy or is otherwise invalid. 
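A minimal construction sketch using the configuration properties from the table above; the endpoint and credentials are placeholders and assume a reachable registry.

    from confluent_kafka.schema_registry import SchemaRegistryClient

    client = SchemaRegistryClient({
        'url': 'http://localhost:8081',
        'basic.auth.user.info': 'sr-user:sr-secret',   # username:password
        'timeout': 10,
        'cache.capacity': 1000,
        'cache.latest.ttl.sec': 300,
        'max.retries': 3,
        'retries.wait.ms': 1000,
    })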
- - See Also: - `POST Subject API Reference `_ - """ # noqa: E501 - - schema_id = self._cache.get_id_by_schema(subject_name, schema) - if schema_id is not None: - return RegisteredSchema(schema_id, schema, subject_name, None) - - request = schema.to_dict() - - response = self._rest_client.post( - 'subjects/{}/versions?normalize={}'.format(_urlencode(subject_name), normalize_schemas), - body=request) - - registered_schema = RegisteredSchema.from_dict(response) - - # The registered schema may not be fully populated - self._cache.set_schema(subject_name, registered_schema.schema_id, schema) - - return registered_schema - - def get_schema( - self, schema_id: int, subject_name: Optional[str] = None, fmt: Optional[str] = None - ) -> 'Schema': - """ - Fetches the schema associated with ``schema_id`` from the - Schema Registry. The result is cached so subsequent attempts will not - require an additional round-trip to the Schema Registry. - - Args: - schema_id (int): Schema id - subject_name (str): Subject name the schema is registered under - fmt (str): Format of the schema - - Returns: - Schema: Schema instance identified by the ``schema_id`` - - Raises: - SchemaRegistryError: If schema can't be found. - - See Also: - `GET Schema API Reference `_ - """ # noqa: E501 - - schema = self._cache.get_schema_by_id(subject_name, schema_id) - if schema is not None: - return schema - - query = {'subject': subject_name} if subject_name is not None else None - if fmt is not None: - if query is not None: - query['format'] = fmt - else: - query = {'format': fmt} - response = self._rest_client.get('schemas/ids/{}'.format(schema_id), query) - - schema = Schema.from_dict(response) - - self._cache.set_schema(subject_name, schema_id, schema) - - return schema - - def lookup_schema( - self, subject_name: str, schema: 'Schema', - normalize_schemas: bool = False, deleted: bool = False - ) -> 'RegisteredSchema': - """ - Returns ``schema`` registration information for ``subject``. - - Args: - subject_name (str): Subject name the schema is registered under - schema (Schema): Schema instance. - normalize_schemas (bool): Normalize schema before registering - deleted (bool): Whether to include deleted schemas. - - Returns: - RegisteredSchema: Subject registration information for this schema. 
- - Raises: - SchemaRegistryError: If schema or subject can't be found - - See Also: - `POST Subject API Reference `_ - """ # noqa: E501 - - registered_schema = self._cache.get_registered_by_subject_schema(subject_name, schema) - if registered_schema is not None: - return registered_schema - - request = schema.to_dict() - - response = self._rest_client.post('subjects/{}?normalize={}&deleted={}' - .format(_urlencode(subject_name), normalize_schemas, deleted), - body=request) - - result = RegisteredSchema.from_dict(response) - - # Ensure the schema matches the input - registered_schema = RegisteredSchema( - schema_id=result.schema_id, - subject=result.subject, - version=result.version, - schema=schema, - ) - - self._cache.set_registered_schema(schema, registered_schema) - - return registered_schema - - def get_subjects(self) -> List[str]: - """ - List all subjects registered with the Schema Registry - - Returns: - list(str): Registered subject names - - Raises: - SchemaRegistryError: if subjects can't be found - - See Also: - `GET subjects API Reference `_ - """ # noqa: E501 - - return self._rest_client.get('subjects') - - def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int]: - """ - Deletes the specified subject and its associated compatibility level if - registered. It is recommended to use this API only when a topic needs - to be recycled or in development environments. - - Args: - subject_name (str): subject name - permanent (bool): True for a hard delete, False (default) for a soft delete - - Returns: - list(int): Versions deleted under this subject - - Raises: - SchemaRegistryError: if the request was unsuccessful. - - See Also: - `DELETE Subject API Reference `_ - """ # noqa: E501 - - if permanent: - versions = self._rest_client.delete('subjects/{}?permanent=true' - .format(_urlencode(subject_name))) - self._cache.remove_by_subject(subject_name) - else: - versions = self._rest_client.delete('subjects/{}' - .format(_urlencode(subject_name))) - - return versions - - def get_latest_version( - self, subject_name: str, fmt: Optional[str] = None - ) -> 'RegisteredSchema': - """ - Retrieves latest registered version for subject - - Args: - subject_name (str): Subject name. - fmt (str): Format of the schema - - Returns: - RegisteredSchema: Registration information for this version. - - Raises: - SchemaRegistryError: if the version can't be found or is invalid. - - See Also: - `GET Subject Version API Reference `_ - """ # noqa: E501 - - registered_schema = self._latest_version_cache.get(subject_name, None) - if registered_schema is not None: - return registered_schema - - query = {'format': fmt} if fmt is not None else None - response = self._rest_client.get('subjects/{}/versions/{}' - .format(_urlencode(subject_name), - 'latest'), query) - - registered_schema = RegisteredSchema.from_dict(response) - - self._latest_version_cache[subject_name] = registered_schema - - return registered_schema - - def get_latest_with_metadata( - self, subject_name: str, metadata: Dict[str, str], - deleted: bool = False, fmt: Optional[str] = None - ) -> 'RegisteredSchema': - """ - Retrieves latest registered version for subject with the given metadata - - Args: - subject_name (str): Subject name. - metadata (dict): The key-value pairs for the metadata. - deleted (bool): Whether to include deleted schemas. - fmt (str): Format of the schema - - Returns: - RegisteredSchema: Registration information for this version. 
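Continuing the sketch above (same placeholder client and a made-up subject name), registration and the cached lookups could be exercised as follows, assuming Schema stays importable from the package root after the refactor.

    from confluent_kafka.schema_registry import Schema

    avro_schema = Schema(
        schema_str='{"type": "record", "name": "Example",'
                   ' "fields": [{"name": "id", "type": "long"}]}',
        schema_type='AVRO',
    )

    schema_id = client.register_schema('example-value', avro_schema)   # -> int id
    registration = client.lookup_schema('example-value', avro_schema)  # -> RegisteredSchema
    latest = client.get_latest_version('example-value')                # served from the TTL/LRU cache when warm
    print(schema_id, registration.version, latest.version)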
- - Raises: - SchemaRegistryError: if the version can't be found or is invalid. - """ # noqa: E501 - - cache_key = (subject_name, frozenset(metadata.items()), deleted) - registered_schema = self._latest_with_metadata_cache.get(cache_key, None) - if registered_schema is not None: - return registered_schema - - query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} - keys = metadata.keys() - if keys: - query['key'] = [_urlencode(key) for key in keys] - query['value'] = [_urlencode(metadata[key]) for key in keys] - - response = self._rest_client.get('subjects/{}/metadata' - .format(_urlencode(subject_name)), query) - - registered_schema = RegisteredSchema.from_dict(response) - - self._latest_with_metadata_cache[cache_key] = registered_schema - - return registered_schema - - def get_version( - self, subject_name: str, version: int, - deleted: bool = False, fmt: Optional[str] = None - ) -> 'RegisteredSchema': - """ - Retrieves a specific schema registered under ``subject_name``. - - Args: - subject_name (str): Subject name. - version (int): version number. Defaults to latest version. - deleted (bool): Whether to include deleted schemas. - fmt (str): Format of the schema - - Returns: - RegisteredSchema: Registration information for this version. - - Raises: - SchemaRegistryError: if the version can't be found or is invalid. - - See Also: - `GET Subject Version API Reference `_ - """ # noqa: E501 - - registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) - if registered_schema is not None: - return registered_schema - - query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} - response = self._rest_client.get('subjects/{}/versions/{}' - .format(_urlencode(subject_name), - version), query) - - registered_schema = RegisteredSchema.from_dict(response) - - self._cache.set_registered_schema(registered_schema.schema, registered_schema) - - return registered_schema - - def get_versions(self, subject_name: str) -> List[int]: - """ - Get a list of all versions registered with this subject. - - Args: - subject_name (str): Subject name. - - Returns: - list(int): Registered versions - - Raises: - SchemaRegistryError: If subject can't be found - - See Also: - `GET Subject Versions API Reference `_ - """ # noqa: E501 - - return self._rest_client.get('subjects/{}/versions'.format(_urlencode(subject_name))) - - def delete_version(self, subject_name: str, version: int, permanent: bool = False) -> int: - """ - Deletes a specific version registered to ``subject_name``. - - Args: - subject_name (str) Subject name - - version (int): Version number - - permanent (bool): True for a hard delete, False (default) for a soft delete - - Returns: - int: Version number which was deleted - - Raises: - SchemaRegistryError: if the subject or version cannot be found. - - See Also: - `Delete Subject Version API Reference `_ - """ # noqa: E501 - - if permanent: - response = self._rest_client.delete('subjects/{}/versions/{}?permanent=true' - .format(_urlencode(subject_name), - version)) - self._cache.remove_by_subject_version(subject_name, version) - else: - response = self._rest_client.delete('subjects/{}/versions/{}' - .format(_urlencode(subject_name), - version)) - - return response - - def set_compatibility(self, subject_name: Optional[str] = None, level: Optional[str] = None) -> str: - """ - Update global or subject level compatibility level. - - Args: - level (str): Compatibility level. 
See API reference for a list of - valid values. - - subject_name (str, optional): Subject to update. Sets global compatibility - level policy if not set. - - Returns: - str: The newly configured compatibility level. - - Raises: - SchemaRegistryError: If the compatibility level is invalid. - - See Also: - `PUT Subject Compatibility API Reference `_ - """ # noqa: E501 - - if level is None: - raise ValueError("level must be set") - - if subject_name is None: - return self._rest_client.put('config', - body={'compatibility': level.upper()}) - - return self._rest_client.put('config/{}' - .format(_urlencode(subject_name)), - body={'compatibility': level.upper()}) - - def get_compatibility(self, subject_name: Optional[str] = None) -> str: - """ - Get the current compatibility level. - - Args: - subject_name (str, optional): Subject name. Returns global policy - if left unset. - - Returns: - str: Compatibility level for the subject if set, otherwise the global compatibility level. - - Raises: - SchemaRegistryError: if the request was unsuccessful. - - See Also: - `GET Subject Compatibility API Reference `_ - """ # noqa: E501 - - if subject_name is not None: - url = 'config/{}'.format(_urlencode(subject_name)) - else: - url = 'config' - - result = self._rest_client.get(url) - return result['compatibilityLevel'] - - def test_compatibility( - self, subject_name: str, schema: 'Schema', - version: Union[int, str] = "latest" - ) -> bool: - """Test the compatibility of a candidate schema for a given subject and version - - Args: - subject_name (str): Subject name the schema is registered under - schema (Schema): Schema instance. - version (int or str, optional): Version number, or the string "latest". Defaults to "latest". - - Returns: - bool: True if the schema is compatible with the specified version - - Raises: - SchemaRegistryError: if the request was unsuccessful. - - See Also: - `POST Test Compatibility API Reference `_ - """ # noqa: E501 - - request = schema.to_dict() - - response = self._rest_client.post( - 'compatibility/subjects/{}/versions/{}'.format(_urlencode(subject_name), version), body=request - ) - - return response['is_compatible'] - - def set_config( - self, subject_name: Optional[str] = None, - config: Optional['ServerConfig'] = None - ) -> 'ServerConfig': - """ - Update global or subject config. - - Args: - config (ServerConfig): Config. See API reference for a list of - valid values. - - subject_name (str, optional): Subject to update. Sets global config - if not set. - - Returns: - str: The newly configured config. - - Raises: - SchemaRegistryError: If the config is invalid. - - See Also: - `PUT Subject Config API Reference `_ - """ # noqa: E501 - - if config is None: - raise ValueError("config must be set") - - if subject_name is None: - return self._rest_client.put('config', - body=config.to_dict()) - - return self._rest_client.put('config/{}' - .format(_urlencode(subject_name)), - body=config.to_dict()) - - def get_config(self, subject_name: Optional[str] = None) -> 'ServerConfig': - """ - Get the current config. - - Args: - subject_name (str, optional): Subject name. Returns global config - if left unset. - - Returns: - ServerConfig: Config for the subject if set, otherwise the global config. - - Raises: - SchemaRegistryError: if the request was unsuccessful. 
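The compatibility and config endpoints above can be combined as in this sketch (same placeholder client and subject; ServerConfig and ConfigCompatibilityLevel are the classes defined further below).

    from confluent_kafka.schema_registry import ConfigCompatibilityLevel, ServerConfig

    client.set_compatibility(subject_name='example-value', level='BACKWARD')
    print(client.get_compatibility('example-value'))                 # -> 'BACKWARD'
    print(client.test_compatibility('example-value', avro_schema))   # checked against "latest"

    # Subject-level config can carry more than the compatibility level:
    client.set_config('example-value',
                      ServerConfig(compatibility=ConfigCompatibilityLevel.FULL))
    print(client.get_config('example-value').compatibility_level)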
- - See Also: - `GET Subject Config API Reference `_ - """ # noqa: E501 - - if subject_name is not None: - url = 'config/{}'.format(_urlencode(subject_name)) - else: - url = 'config' - - result = self._rest_client.get(url) - return ServerConfig.from_dict(result) - - def clear_latest_caches(self): - self._latest_version_cache.clear() - self._latest_with_metadata_cache.clear() - - def clear_caches(self): - self._latest_version_cache.clear() - self._latest_with_metadata_cache.clear() - self._cache.clear() - - @staticmethod - def new_client(conf: dict) -> 'SchemaRegistryClient': - from .mock_schema_registry_client import MockSchemaRegistryClient - url = conf.get("url") - if url.startswith("mock://"): - return MockSchemaRegistryClient(conf) - return SchemaRegistryClient(conf) - - -T = TypeVar("T") - - -class RuleKind(str, Enum): - CONDITION = "CONDITION" - TRANSFORM = "TRANSFORM" - - def __str__(self) -> str: - return str(self.value) - - -class RuleMode(str, Enum): - UPGRADE = "UPGRADE" - DOWNGRADE = "DOWNGRADE" - UPDOWN = "UPDOWN" - READ = "READ" - WRITE = "WRITE" - WRITEREAD = "WRITEREAD" - - def __str__(self) -> str: - return str(self.value) - - -@_attrs_define -class RuleParams: - params: Dict[str, str] = _attrs_field(factory=dict, hash=False) - - def to_dict(self) -> Dict[str, Any]: - field_dict: Dict[str, Any] = {} - field_dict.update(self.params) - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = src_dict.copy() - - rule_params = cls(params=d) - - return rule_params - - def __hash__(self): - return hash(frozenset(self.params.items())) - - -@_attrs_define(frozen=True) -class Rule: - name: Optional[str] - doc: Optional[str] - kind: Optional[RuleKind] - mode: Optional[RuleMode] - type: Optional[str] - tags: Optional[List[str]] = _attrs_field(hash=False) - params: Optional[RuleParams] - expr: Optional[str] - on_success: Optional[str] - on_failure: Optional[str] - disabled: Optional[bool] - - def to_dict(self) -> Dict[str, Any]: - name = self.name - - doc = self.doc - - kind_str: Optional[str] = None - if self.kind is not None: - kind_str = self.kind.value - - mode_str: Optional[str] = None - if self.mode is not None: - mode_str = self.mode.value - - rule_type = self.type - - tags = self.tags - - _params: Optional[Dict[str, Any]] = None - if self.params is not None: - _params = self.params.to_dict() - - expr = self.expr - - on_success = self.on_success - - on_failure = self.on_failure - - disabled = self.disabled - - field_dict: Dict[str, Any] = {} - field_dict.update({}) - if name is not None: - field_dict["name"] = name - if doc is not None: - field_dict["doc"] = doc - if kind_str is not None: - field_dict["kind"] = kind_str - if mode_str is not None: - field_dict["mode"] = mode_str - if type is not None: - field_dict["type"] = rule_type - if tags is not None: - field_dict["tags"] = tags - if _params is not None: - field_dict["params"] = _params - if expr is not None: - field_dict["expr"] = expr - if on_success is not None: - field_dict["onSuccess"] = on_success - if on_failure is not None: - field_dict["onFailure"] = on_failure - if disabled is not None: - field_dict["disabled"] = disabled - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = src_dict.copy() - name = d.pop("name", None) - - doc = d.pop("doc", None) - - _kind = d.pop("kind", None) - kind: Optional[RuleKind] = None - if _kind is not None: - kind = RuleKind(_kind) - - _mode = d.pop("mode", None) - mode: 
Optional[RuleMode] = None - if _mode is not None: - mode = RuleMode(_mode) - - rule_type = d.pop("type", None) - - tags = cast(List[str], d.pop("tags", None)) - - _params: Optional[Dict[str, Any]] = d.pop("params", None) - params: Optional[RuleParams] = None - if _params is not None: - params = RuleParams.from_dict(_params) - - expr = d.pop("expr", None) - - on_success = d.pop("onSuccess", None) - - on_failure = d.pop("onFailure", None) - - disabled = d.pop("disabled", None) - - rule = cls( - name=name, - doc=doc, - kind=kind, - mode=mode, - type=rule_type, - tags=tags, - params=params, - expr=expr, - on_success=on_success, - on_failure=on_failure, - disabled=disabled, - ) - - return rule - - -@_attrs_define -class RuleSet: - migration_rules: Optional[List["Rule"]] = _attrs_field(hash=False) - domain_rules: Optional[List["Rule"]] = _attrs_field(hash=False) - - def to_dict(self) -> Dict[str, Any]: - _migration_rules: Optional[List[Dict[str, Any]]] = None - if self.migration_rules is not None: - _migration_rules = [] - for migration_rules_item_data in self.migration_rules: - migration_rules_item = migration_rules_item_data.to_dict() - _migration_rules.append(migration_rules_item) - - _domain_rules: Optional[List[Dict[str, Any]]] = None - if self.domain_rules is not None: - _domain_rules = [] - for domain_rules_item_data in self.domain_rules: - domain_rules_item = domain_rules_item_data.to_dict() - _domain_rules.append(domain_rules_item) - - field_dict: Dict[str, Any] = {} - field_dict.update({}) - if _migration_rules is not None: - field_dict["migrationRules"] = _migration_rules - if _domain_rules is not None: - field_dict["domainRules"] = _domain_rules - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = src_dict.copy() - migration_rules = [] - _migration_rules = d.pop("migrationRules", None) - for migration_rules_item_data in _migration_rules or []: - migration_rules_item = Rule.from_dict(migration_rules_item_data) - migration_rules.append(migration_rules_item) - - domain_rules = [] - _domain_rules = d.pop("domainRules", None) - for domain_rules_item_data in _domain_rules or []: - domain_rules_item = Rule.from_dict(domain_rules_item_data) - domain_rules.append(domain_rules_item) - - rule_set = cls( - migration_rules=migration_rules, - domain_rules=domain_rules, - ) - - return rule_set - - def __hash__(self): - return hash(frozenset((self.migration_rules or []) + (self.domain_rules or []))) - - -@_attrs_define -class MetadataTags: - tags: Dict[str, List[str]] = _attrs_field(factory=dict, hash=False) - - def to_dict(self) -> Dict[str, Any]: - field_dict: Dict[str, Any] = {} - for prop_name, prop in self.tags.items(): - field_dict[prop_name] = prop - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = src_dict.copy() - - tags = {} - for prop_name, prop_dict in d.items(): - tag = cast(List[str], prop_dict) - - tags[prop_name] = tag - - metadata_tags = cls(tags=tags) - - return metadata_tags - - def __hash__(self): - return hash(frozenset(self.tags.items())) - - -@_attrs_define -class MetadataProperties: - properties: Dict[str, str] = _attrs_field(factory=dict, hash=False) - - def to_dict(self) -> Dict[str, Any]: - field_dict: Dict[str, Any] = {} - field_dict.update(self.properties) - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = src_dict.copy() - - metadata_properties = cls(properties=d) - - return metadata_properties - - def 
__hash__(self): - return hash(frozenset(self.properties.items())) - - -@_attrs_define(frozen=True) -class Metadata: - tags: Optional[MetadataTags] - properties: Optional[MetadataProperties] - sensitive: Optional[List[str]] = _attrs_field(hash=False) - - def to_dict(self) -> Dict[str, Any]: - _tags: Optional[Dict[str, Any]] = None - if self.tags is not None: - _tags = self.tags.to_dict() - - _properties: Optional[Dict[str, Any]] = None - if self.properties is not None: - _properties = self.properties.to_dict() - - sensitive: Optional[List[str]] = None - if self.sensitive is not None: - sensitive = [] - for sensitive_item in self.sensitive: - sensitive.append(sensitive_item) - - field_dict: Dict[str, Any] = {} - if _tags is not None: - field_dict["tags"] = _tags - if _properties is not None: - field_dict["properties"] = _properties - if sensitive is not None: - field_dict["sensitive"] = sensitive - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = src_dict.copy() - _tags: Optional[Dict[str, Any]] = d.pop("tags", None) - tags: Optional[MetadataTags] = None - if _tags is not None: - tags = MetadataTags.from_dict(_tags) - - _properties: Optional[Dict[str, Any]] = d.pop("properties", None) - properties: Optional[MetadataProperties] = None - if _properties is not None: - properties = MetadataProperties.from_dict(_properties) - - sensitive = [] - _sensitive = d.pop("sensitive", None) - for sensitive_item in _sensitive or []: - sensitive.append(sensitive_item) - - metadata = cls( - tags=tags, - properties=properties, - sensitive=sensitive, - ) - - return metadata - - -@_attrs_define(frozen=True) -class SchemaReference: - name: Optional[str] - subject: Optional[str] - version: Optional[int] - - def to_dict(self) -> Dict[str, Any]: - name = self.name - - subject = self.subject - - version = self.version - - field_dict: Dict[str, Any] = {} - if name is not None: - field_dict["name"] = name - if subject is not None: - field_dict["subject"] = subject - if version is not None: - field_dict["version"] = version - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = src_dict.copy() - name = d.pop("name", None) - - subject = d.pop("subject", None) - - version = d.pop("version", None) - - schema_reference = cls( - name=name, - subject=subject, - version=version, - ) - - return schema_reference - - -class ConfigCompatibilityLevel(str, Enum): - BACKWARD = "BACKWARD" - BACKWARD_TRANSITIVE = "BACKWARD_TRANSITIVE" - FORWARD = "FORWARD" - FORWARD_TRANSITIVE = "FORWARD_TRANSITIVE" - FULL = "FULL" - FULL_TRANSITIVE = "FULL_TRANSITIVE" - NONE = "NONE" - - def __str__(self) -> str: - return str(self.value) - - -@_attrs_define -class ServerConfig: - compatibility: Optional[ConfigCompatibilityLevel] = None - compatibility_level: Optional[ConfigCompatibilityLevel] = None - compatibility_group: Optional[str] = None - default_metadata: Optional[Metadata] = None - override_metadata: Optional[Metadata] = None - default_rule_set: Optional[RuleSet] = None - override_rule_set: Optional[RuleSet] = None - - def to_dict(self) -> Dict[str, Any]: - _compatibility: Optional[str] = None - if self.compatibility is not None: - _compatibility = self.compatibility.value - - _compatibility_level: Optional[str] = None - if self.compatibility_level is not None: - _compatibility_level = self.compatibility_level.value - - compatibility_group = self.compatibility_group - - _default_metadata: Optional[Dict[str, Any]] - if 
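A round-trip sketch for the rule model above; the rule type and expression are illustrative, and the imports assume these classes remain re-exported from schema_registry_client after the refactor.

    from confluent_kafka.schema_registry.schema_registry_client import (
        Rule, RuleKind, RuleMode, RuleParams, RuleSet,
    )

    rule = Rule(
        name='requireName',
        doc='Reject records with an empty name field',
        kind=RuleKind.CONDITION,
        mode=RuleMode.WRITE,
        type='CEL',                      # hypothetical rule type for this sketch
        tags=None,
        params=RuleParams({'example.param': 'value'}),
        expr='size(message.name) > 0',
        on_success=None,
        on_failure='ERROR',
        disabled=False,
    )
    rule_set = RuleSet(migration_rules=None, domain_rules=[rule])

    wire = rule_set.to_dict()            # {'domainRules': [{'name': 'requireName', ...}]}
    parsed = RuleSet.from_dict(wire)
    assert parsed.domain_rules[0].expr == rule.expr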
isinstance(self.default_metadata, Metadata): - _default_metadata = self.default_metadata.to_dict() - else: - _default_metadata = self.default_metadata - - _override_metadata: Optional[Dict[str, Any]] - if isinstance(self.override_metadata, Metadata): - _override_metadata = self.override_metadata.to_dict() - else: - _override_metadata = self.override_metadata - - _default_rule_set: Optional[Dict[str, Any]] - if isinstance(self.default_rule_set, RuleSet): - _default_rule_set = self.default_rule_set.to_dict() - else: - _default_rule_set = self.default_rule_set - - _override_rule_set: Optional[Dict[str, Any]] - if isinstance(self.override_rule_set, RuleSet): - _override_rule_set = self.override_rule_set.to_dict() - else: - _override_rule_set = self.override_rule_set - - field_dict: Dict[str, Any] = {} - if _compatibility is not None: - field_dict["compatibility"] = _compatibility - if _compatibility_level is not None: - field_dict["compatibilityLevel"] = _compatibility_level - if compatibility_group is not None: - field_dict["compatibilityGroup"] = compatibility_group - if _default_metadata is not None: - field_dict["defaultMetadata"] = _default_metadata - if _override_metadata is not None: - field_dict["overrideMetadata"] = _override_metadata - if _default_rule_set is not None: - field_dict["defaultRuleSet"] = _default_rule_set - if _override_rule_set is not None: - field_dict["overrideRuleSet"] = _override_rule_set - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = src_dict.copy() - _compatibility = d.pop("compatibility", None) - compatibility: Optional[ConfigCompatibilityLevel] - if _compatibility is None: - compatibility = None - else: - compatibility = ConfigCompatibilityLevel(_compatibility) - - _compatibility_level = d.pop("compatibilityLevel", None) - compatibility_level: Optional[ConfigCompatibilityLevel] - if _compatibility_level is None: - compatibility_level = None - else: - compatibility_level = ConfigCompatibilityLevel(_compatibility_level) - - compatibility_group = d.pop("compatibilityGroup", None) - - def _parse_default_metadata(data: object) -> Optional[Metadata]: - if data is None: - return data - if not isinstance(data, dict): - raise TypeError() - return Metadata.from_dict(data) - - default_metadata = _parse_default_metadata(d.pop("defaultMetadata", None)) - - def _parse_override_metadata(data: object) -> Optional[Metadata]: - if data is None: - return data - if not isinstance(data, dict): - raise TypeError() - return Metadata.from_dict(data) - - override_metadata = _parse_override_metadata(d.pop("overrideMetadata", None)) - - def _parse_default_rule_set(data: object) -> Optional[RuleSet]: - if data is None: - return data - if not isinstance(data, dict): - raise TypeError() - return RuleSet.from_dict(data) - - default_rule_set = _parse_default_rule_set(d.pop("defaultRuleSet", None)) - - def _parse_override_rule_set(data: object) -> Optional[RuleSet]: - if data is None: - return data - if not isinstance(data, dict): - raise TypeError() - return RuleSet.from_dict(data) - - override_rule_set = _parse_override_rule_set(d.pop("overrideRuleSet", None)) - - config = cls( - compatibility=compatibility, - compatibility_level=compatibility_level, - compatibility_group=compatibility_group, - default_metadata=default_metadata, - override_metadata=override_metadata, - default_rule_set=default_rule_set, - override_rule_set=override_rule_set, - ) - - return config - - -@_attrs_define(frozen=True, cache_hash=True) -class Schema: - """ 
- An unregistered schema. - """ - - schema_str: Optional[str] - schema_type: Optional[str] = "AVRO" - references: Optional[List[SchemaReference]] = _attrs_field(factory=list, hash=False) - metadata: Optional[Metadata] = None - rule_set: Optional[RuleSet] = None - - def to_dict(self) -> Dict[str, Any]: - schema = self.schema_str - - schema_type = self.schema_type - - _references: Optional[List[Dict[str, Any]]] = [] - if self.references is not None: - for references_item_data in self.references: - references_item = references_item_data.to_dict() - _references.append(references_item) - - _metadata: Optional[Dict[str, Any]] = None - if isinstance(self.metadata, Metadata): - _metadata = self.metadata.to_dict() - - _rule_set: Optional[Dict[str, Any]] = None - if isinstance(self.rule_set, RuleSet): - _rule_set = self.rule_set.to_dict() - - field_dict: Dict[str, Any] = {} - if schema is not None: - field_dict["schema"] = schema - if schema_type is not None: - field_dict["schemaType"] = schema_type - if _references is not None: - field_dict["references"] = _references - if _metadata is not None: - field_dict["metadata"] = _metadata - if _rule_set is not None: - field_dict["ruleSet"] = _rule_set - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = src_dict.copy() - - schema = d.pop("schema", None) - - schema_type = d.pop("schemaType", "AVRO") - - references = [] - _references = d.pop("references", None) - for references_item_data in _references or []: - references_item = SchemaReference.from_dict(references_item_data) - - references.append(references_item) - - def _parse_metadata(data: object) -> Optional[Metadata]: - if data is None: - return data - if not isinstance(data, dict): - raise TypeError() - return Metadata.from_dict(data) - - metadata = _parse_metadata(d.pop("metadata", None)) - - def _parse_rule_set(data: object) -> Optional[RuleSet]: - if data is None: - return data - if not isinstance(data, dict): - raise TypeError() - return RuleSet.from_dict(data) - - rule_set = _parse_rule_set(d.pop("ruleSet", None)) - - schema = cls( - schema_str=schema, - schema_type=schema_type, - references=references, - metadata=metadata, - rule_set=rule_set, - ) - - return schema - - -@_attrs_define(frozen=True, cache_hash=True) -class RegisteredSchema: - """ - An registered schema. 
- """ - - schema_id: Optional[int] - schema: Optional[Schema] - subject: Optional[str] - version: Optional[int] - - def to_dict(self) -> Dict[str, Any]: - schema = self.schema - - schema_id = self.schema_id - - subject = self.subject - - version = self.version - - field_dict: Dict[str, Any] = {} - if schema is not None: - field_dict = schema.to_dict() - if schema_id is not None: - field_dict["id"] = schema_id - if subject is not None: - field_dict["subject"] = subject - if version is not None: - field_dict["version"] = version - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = src_dict.copy() - - schema = Schema.from_dict(d) - - schema_id = d.pop("id", None) - - subject = d.pop("subject", None) - - version = d.pop("version", None) - schema = cls( - schema_id=schema_id, - schema=schema, - subject=subject, - version=version, - ) +from .common.schema_registry_client import * +from ._sync.schema_registry_client import * - return schema +from .error import SchemaRegistryError diff --git a/src/confluent_kafka/schema_registry/serde.py b/src/confluent_kafka/schema_registry/serde.py index 1cfb384e1..3f351349c 100644 --- a/src/confluent_kafka/schema_registry/serde.py +++ b/src/confluent_kafka/schema_registry/serde.py @@ -16,524 +16,5 @@ # limitations under the License. # -__all__ = ['BaseSerializer', - 'BaseDeserializer', - 'FieldContext', - 'FieldRuleExecutor', - 'FieldTransform', - 'FieldTransformer', - 'FieldType', - 'ParsedSchemaCache', - 'RuleAction', - 'RuleContext', - 'RuleConditionError', - 'RuleError', - 'RuleExecutor'] - -import abc -import logging -from enum import Enum -from threading import Lock -from typing import Callable, List, Optional, Set, Dict, Any, TypeVar - -from confluent_kafka.schema_registry import RegisteredSchema -from confluent_kafka.schema_registry.schema_registry_client import RuleMode, \ - Rule, RuleKind, Schema, RuleSet -from confluent_kafka.schema_registry.wildcard_matcher import wildcard_match -from confluent_kafka.serialization import Serializer, Deserializer, \ - SerializationContext, SerializationError - - -log = logging.getLogger(__name__) - - -class FieldType(str, Enum): - RECORD = "RECORD" - ENUM = "ENUM" - ARRAY = "ARRAY" - MAP = "MAP" - COMBINED = "COMBINED" - FIXED = "FIXED" - STRING = "STRING" - BYTES = "BYTES" - INT = "INT" - LONG = "LONG" - FLOAT = "FLOAT" - DOUBLE = "DOUBLE" - BOOLEAN = "BOOLEAN" - NULL = "NULL" - - -class FieldContext(object): - __slots__ = ['containing_message', 'full_name', 'name', 'field_type', 'tags'] - - def __init__( - self, containing_message: Any, full_name: str, name: str, - field_type: FieldType, tags: Set[str] - ): - self.containing_message = containing_message - self.full_name = full_name - self.name = name - self.field_type = field_type - self.tags = tags - - def is_primitive(self) -> bool: - return self.field_type in (FieldType.INT, FieldType.LONG, FieldType.FLOAT, - FieldType.DOUBLE, FieldType.BOOLEAN, FieldType.NULL, - FieldType.STRING, FieldType.BYTES) - - def type_name(self) -> str: - return self.field_type.name - - -class RuleContext(object): - __slots__ = ['ser_ctx', 'source', 'target', 'subject', 'rule_mode', 'rule', - 'index', 'rules', 'inline_tags', 'field_transformer', '_field_contexts'] - - def __init__( - self, ser_ctx: SerializationContext, source: Optional[Schema], - target: Optional[Schema], subject: str, rule_mode: RuleMode, rule: Rule, - index: int, rules: List[Rule], inline_tags: Optional[Dict[str, Set[str]]], field_transformer - ): - self.ser_ctx = 
ser_ctx - self.source = source - self.target = target - self.subject = subject - self.rule_mode = rule_mode - self.rule = rule - self.index = index - self.rules = rules - self.inline_tags = inline_tags - self.field_transformer = field_transformer - self._field_contexts: List[FieldContext] = [] - - def get_parameter(self, name: str) -> Optional[str]: - params = self.rule.params - if params is not None: - value = params.params.get(name) - if value is not None: - return value - if (self.target is not None - and self.target.metadata is not None - and self.target.metadata.properties is not None): - value = self.target.metadata.properties.properties.get(name) - if value is not None: - return value - return None - - def _get_inline_tags(self, name: str) -> Set[str]: - if self.inline_tags is None: - return set() - return self.inline_tags.get(name, set()) - - def current_field(self) -> Optional[FieldContext]: - if not self._field_contexts: - return None - return self._field_contexts[-1] - - def enter_field( - self, containing_message: Any, full_name: str, name: str, - field_type: FieldType, tags: Optional[Set[str]] - ) -> FieldContext: - all_tags = set(tags if tags is not None else self._get_inline_tags(full_name)) - all_tags.update(self.get_tags(full_name)) - field_context = FieldContext(containing_message, full_name, name, field_type, all_tags) - self._field_contexts.append(field_context) - return field_context - - def get_tags(self, full_name: str) -> Set[str]: - result = set() - if (self.target is not None - and self.target.metadata is not None - and self.target.metadata.tags is not None): - tags = self.target.metadata.tags.tags - for k, v in tags.items(): - if wildcard_match(full_name, k): - result.update(v) - return result - - def exit_field(self): - if self._field_contexts: - self._field_contexts.pop() - - -FieldTransform = Callable[[RuleContext, FieldContext, Any], Any] - - -FieldTransformer = Callable[[RuleContext, FieldTransform, Any], Any] - - -class RuleBase(metaclass=abc.ABCMeta): - def configure(self, client_conf: dict, rule_conf: dict): - pass - - @abc.abstractmethod - def type(self) -> str: - raise NotImplementedError() - - def close(self): - pass - - -class RuleExecutor(RuleBase): - @abc.abstractmethod - def transform(self, ctx: RuleContext, message: Any) -> Any: - raise NotImplementedError() - - -class FieldRuleExecutor(RuleExecutor): - @abc.abstractmethod - def new_transform(self, ctx: RuleContext) -> FieldTransform: - raise NotImplementedError() - - def transform(self, ctx: RuleContext, message: Any) -> Any: - # TODO preserve source - if ctx.rule_mode in (RuleMode.WRITE, RuleMode.UPGRADE): - for i in range(ctx.index): - other_rule = ctx.rules[i] - if FieldRuleExecutor.are_transforms_with_same_tag(ctx.rule, other_rule): - # ignore this transform if an earlier one has the same tag - return message - elif ctx.rule_mode == RuleMode.READ or ctx.rule_mode == RuleMode.DOWNGRADE: - for i in range(ctx.index + 1, len(ctx.rules)): - other_rule = ctx.rules[i] - if FieldRuleExecutor.are_transforms_with_same_tag(ctx.rule, other_rule): - # ignore this transform if a later one has the same tag - return message - return ctx.field_transformer(ctx, self.new_transform(ctx), message) - - @staticmethod - def are_transforms_with_same_tag(rule1: Rule, rule2: Rule) -> bool: - return (bool(rule1.tags) - and rule1.kind == RuleKind.TRANSFORM - and rule1.kind == rule2.kind - and rule1.mode == rule2.mode - and rule1.type == rule2.type - and rule1.tags == rule2.tags) - - -class RuleAction(RuleBase): - 
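A minimal custom executor sketch against the RuleExecutor interface above; the type name is made up, registration with the rule registry is out of scope here, and the import assumes RuleContext and RuleExecutor stay importable from confluent_kafka.schema_registry.serde.

    from typing import Any

    from confluent_kafka.schema_registry.serde import RuleContext, RuleExecutor

    class UpperCaseExecutor(RuleExecutor):
        def type(self) -> str:
            # Matched case-insensitively against Rule.type by the serdes.
            return 'UPPER_CASE'

        def transform(self, ctx: RuleContext, message: Any) -> Any:
            # TRANSFORM rules return the (possibly rewritten) message;
            # CONDITION rules return a truthy/falsy result instead.
            return message.upper() if isinstance(message, str) else message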
@abc.abstractmethod - def run(self, ctx: RuleContext, message: Any, ex: Optional[Exception]): - raise NotImplementedError() - - -class ErrorAction(RuleAction): - def type(self) -> str: - return 'ERROR' - - def run(self, ctx: RuleContext, message: Any, ex: Optional[Exception]): - if ex is None: - raise SerializationError() - else: - raise SerializationError() from ex - - -class NoneAction(RuleAction): - def type(self) -> str: - return 'NONE' - - def run(self, ctx: RuleContext, message: Any, ex: Optional[Exception]): - pass - - -class RuleError(Exception): - pass - - -class RuleConditionError(RuleError): - def __init__(self, rule: Rule): - super().__init__(RuleConditionError.error_message(rule)) - - @staticmethod - def error_message(rule: Rule) -> str: - if rule.doc: - return rule.doc - elif rule.expr: - return f"Rule expr failed: {rule.expr}" - else: - return f"Rule failed: {rule.name}" - - -class Migration(object): - __slots__ = ['rule_mode', 'source', 'target'] - - def __init__( - self, rule_mode: RuleMode, source: Optional[RegisteredSchema], - target: Optional[RegisteredSchema] - ): - self.rule_mode = rule_mode - self.source = source - self.target = target - - -class BaseSerde(object): - __slots__ = ['_use_schema_id', '_use_latest_version', '_use_latest_with_metadata', - '_registry', '_rule_registry', '_subject_name_func', - '_field_transformer'] - - def _get_reader_schema(self, subject: str, fmt: Optional[str] = None) -> Optional[RegisteredSchema]: - if self._use_schema_id is not None: - schema = self._registry.get_schema(self._use_schema_id, subject, fmt) - return self._registry.lookup_schema(subject, schema, False, True) - if self._use_latest_with_metadata is not None: - return self._registry.get_latest_with_metadata( - subject, self._use_latest_with_metadata, True, fmt) - if self._use_latest_version: - return self._registry.get_latest_version(subject, fmt) - return None - - def _execute_rules( - self, ser_ctx: SerializationContext, subject: str, - rule_mode: RuleMode, - source: Optional[Schema], target: Optional[Schema], - message: Any, inline_tags: Optional[Dict[str, Set[str]]], - field_transformer: Optional[FieldTransformer] - ) -> Any: - if message is None or target is None: - return message - rules: Optional[List[Rule]] = None - if rule_mode == RuleMode.UPGRADE: - if target is not None and target.rule_set is not None: - rules = target.rule_set.migration_rules - elif rule_mode == RuleMode.DOWNGRADE: - if source is not None and source.rule_set is not None: - rules = source.rule_set.migration_rules - rules = rules[:] if rules else [] - rules.reverse() - else: - if target is not None and target.rule_set is not None: - rules = target.rule_set.domain_rules - if rule_mode == RuleMode.READ: - # Execute read rules in reverse order for symmetry - rules = rules[:] if rules else [] - rules.reverse() - - if not rules: - return message - - for index in range(len(rules)): - rule = rules[index] - if self._is_disabled(rule): - continue - if rule.mode == RuleMode.WRITEREAD: - if rule_mode != RuleMode.READ and rule_mode != RuleMode.WRITE: - continue - elif rule.mode == RuleMode.UPDOWN: - if rule_mode != RuleMode.UPGRADE and rule_mode != RuleMode.DOWNGRADE: - continue - elif rule.mode != rule_mode: - continue - - ctx = RuleContext(ser_ctx, source, target, subject, rule_mode, rule, - index, rules, inline_tags, field_transformer) - rule_executor = self._rule_registry.get_executor(rule.type.upper()) - if rule_executor is None: - self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), 
message, - RuleError(f"Could not find rule executor of type {rule.type}"), - 'ERROR') - return message - try: - result = rule_executor.transform(ctx, message) - if rule.kind == RuleKind.CONDITION: - if not result: - raise RuleConditionError(rule) - elif rule.kind == RuleKind.TRANSFORM: - message = result - self._run_action( - ctx, rule_mode, rule, - self._get_on_failure(rule) if message is None else self._get_on_success(rule), - message, None, - 'ERROR' if message is None else 'NONE') - except SerializationError: - raise - except Exception as e: - self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), - message, e, 'ERROR') - return message - - def _get_on_success(self, rule: Rule) -> Optional[str]: - override = self._rule_registry.get_override(rule.type) - if override is not None and override.on_success is not None: - return override.on_success - return rule.on_success - - def _get_on_failure(self, rule: Rule) -> Optional[str]: - override = self._rule_registry.get_override(rule.type) - if override is not None and override.on_failure is not None: - return override.on_failure - return rule.on_failure - - def _is_disabled(self, rule: Rule) -> Optional[bool]: - override = self._rule_registry.get_override(rule.type) - if override is not None and override.disabled is not None: - return override.disabled - return rule.disabled - - def _run_action( - self, ctx: RuleContext, rule_mode: RuleMode, rule: Rule, - action: Optional[str], message: Any, - ex: Optional[Exception], default_action: str - ): - action_name = self._get_rule_action_name(rule, rule_mode, action) - if action_name is None: - action_name = default_action - rule_action = self._get_rule_action(ctx, action_name) - if rule_action is None: - log.error("Could not find rule action of type %s", action_name) - raise RuleError(f"Could not find rule action of type {action_name}") - try: - rule_action.run(ctx, message, ex) - except SerializationError: - raise - except Exception as e: - log.warning("Could not run post-rule action %s: %s", action_name, e) - - def _get_rule_action_name( - self, rule: Rule, rule_mode: RuleMode, action_name: Optional[str] - ) -> Optional[str]: - if action_name is None or action_name == "": - return None - if rule.mode in (RuleMode.WRITEREAD, RuleMode.UPDOWN) and ',' in action_name: - parts = action_name.split(',') - if rule_mode in (RuleMode.WRITE, RuleMode.UPGRADE): - return parts[0] - elif rule_mode in (RuleMode.READ, RuleMode.DOWNGRADE): - return parts[1] - return action_name - - def _get_rule_action(self, ctx: RuleContext, action_name: str) -> Optional[RuleAction]: - if action_name == 'ERROR': - return ErrorAction() - elif action_name == 'NONE': - return NoneAction() - return self._rule_registry.get_action(action_name) - - -class BaseSerializer(BaseSerde, Serializer): - __slots__ = ['_auto_register', '_normalize_schemas'] - - -class BaseDeserializer(BaseSerde, Deserializer): - __slots__ = [] - - def _has_rules(self, rule_set: RuleSet, mode: RuleMode) -> bool: - if rule_set is None: - return False - if mode in (RuleMode.UPGRADE, RuleMode.DOWNGRADE): - return any(rule.mode == mode or rule.mode == RuleMode.UPDOWN - for rule in rule_set.migration_rules or []) - elif mode == RuleMode.UPDOWN: - return any(rule.mode == mode for rule in rule_set.migration_rules or []) - elif mode in (RuleMode.WRITE, RuleMode.READ): - return any(rule.mode == mode or rule.mode == RuleMode.WRITEREAD - for rule in rule_set.domain_rules or []) - elif mode == RuleMode.WRITEREAD: - return any(rule.mode == mode for rule in 
rule_set.migration_rules or []) - return False - - def _get_migrations( - self, subject: str, source_info: Schema, - target: RegisteredSchema, fmt: Optional[str] - ) -> List[Migration]: - source = self._registry.lookup_schema(subject, source_info, False, True) - migrations = [] - if source.version < target.version: - migration_mode = RuleMode.UPGRADE - first = source - last = target - elif source.version > target.version: - migration_mode = RuleMode.DOWNGRADE - first = target - last = source - else: - return migrations - previous: Optional[RegisteredSchema] = None - versions = self._get_schemas_between(subject, first, last, fmt) - for i in range(len(versions)): - version = versions[i] - if i == 0: - previous = version - continue - if version.schema.rule_set is not None and self._has_rules(version.schema.rule_set, migration_mode): - if migration_mode == RuleMode.UPGRADE: - migration = Migration(migration_mode, previous, version) - else: - migration = Migration(migration_mode, version, previous) - migrations.append(migration) - previous = version - if migration_mode == RuleMode.DOWNGRADE: - migrations.reverse() - return migrations - - def _get_schemas_between( - self, subject: str, first: RegisteredSchema, - last: RegisteredSchema, fmt: Optional[str] = None - ) -> List[RegisteredSchema]: - if last.version - first.version <= 1: - return [first, last] - version1 = first.version - version2 = last.version - result = [first] - for i in range(version1 + 1, version2): - result.append(self._registry.get_version(subject, i, True, fmt)) - result.append(last) - return result - - def _execute_migrations( - self, ser_ctx: SerializationContext, subject: str, - migrations: List[Migration], message: Any - ) -> Any: - for migration in migrations: - message = self._execute_rules(ser_ctx, subject, migration.rule_mode, - migration.source.schema, migration.target.schema, - message, None, None) - return message - - -T = TypeVar("T") - - -class ParsedSchemaCache(object): - """ - Thread-safe cache for parsed schemas - """ - - def __init__(self): - self.lock = Lock() - self.parsed_schemas = {} - - def set(self, schema: Schema, parsed_schema: T): - """ - Add a Schema identified by schema_id to the cache. - - Args: - schema (Schema): The schema - - parsed_schema (Any): The parsed schema - """ - - with self.lock: - self.parsed_schemas[schema] = parsed_schema - - def get_parsed_schema(self, schema: Schema) -> Optional[T]: - """ - Get the parsed schema associated with the schema - - Args: - schema (Schema): The schema - - Returns: - The parsed schema if known; else None - """ - - with self.lock: - return self.parsed_schemas.get(schema, None) - - def clear(self): - """ - Clear the cache. 
- """ - - with self.lock: - self.parsed_schemas.clear() +from .common.serde import * +from ._sync.serde import * diff --git a/tests/integration/schema_registry/__init__.py b/tests/integration/schema_registry/_sync/__init__.py similarity index 100% rename from tests/integration/schema_registry/__init__.py rename to tests/integration/schema_registry/_sync/__init__.py diff --git a/tests/integration/schema_registry/test_api_client.py b/tests/integration/schema_registry/_sync/test_api_client.py similarity index 100% rename from tests/integration/schema_registry/test_api_client.py rename to tests/integration/schema_registry/_sync/test_api_client.py diff --git a/tests/integration/schema_registry/test_avro_serializers.py b/tests/integration/schema_registry/_sync/test_avro_serializers.py similarity index 99% rename from tests/integration/schema_registry/test_avro_serializers.py rename to tests/integration/schema_registry/_sync/test_avro_serializers.py index 4140ad600..322637ae7 100644 --- a/tests/integration/schema_registry/test_avro_serializers.py +++ b/tests/integration/schema_registry/_sync/test_avro_serializers.py @@ -179,9 +179,11 @@ def _references_test_common(kafka_cluster, awarded_user, serializer_schema, dese producer = kafka_cluster.producer(value_serializer=value_serializer) producer.produce(topic, value=awarded_user, partition=0) + producer.flush() consumer = kafka_cluster.consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) msg = consumer.poll() diff --git a/tests/integration/schema_registry/test_json_serializers.py b/tests/integration/schema_registry/_sync/test_json_serializers.py similarity index 99% rename from tests/integration/schema_registry/test_json_serializers.py rename to tests/integration/schema_registry/_sync/test_json_serializers.py index 5b6700438..ae67c30f2 100644 --- a/tests/integration/schema_registry/test_json_serializers.py +++ b/tests/integration/schema_registry/_sync/test_json_serializers.py @@ -19,7 +19,7 @@ from confluent_kafka import TopicPartition from confluent_kafka.error import ConsumeError, ValueSerializationError -from confluent_kafka.schema_registry import SchemaReference, Schema +from confluent_kafka.schema_registry import SchemaReference, Schema, SchemaRegistryClient from confluent_kafka.schema_registry.json_schema import (JSONSerializer, JSONDeserializer) @@ -404,7 +404,7 @@ def test_json_record_deserialization_mismatch(kafka_cluster, load_file): consumer.poll() -def _register_referenced_schemas(sr, load_file): +def _register_referenced_schemas(sr: SchemaRegistryClient, load_file): sr.register_schema("product", Schema(load_file("product.json"), 'JSON')) sr.register_schema("customer", Schema(load_file("customer.json"), 'JSON')) sr.register_schema("order_details", Schema(load_file("order_details.json"), 'JSON', [ diff --git a/tests/integration/schema_registry/test_proto_serializers.py b/tests/integration/schema_registry/_sync/test_proto_serializers.py similarity index 96% rename from tests/integration/schema_registry/test_proto_serializers.py rename to tests/integration/schema_registry/_sync/test_proto_serializers.py index 16de4ea6b..54e458152 100644 --- a/tests/integration/schema_registry/test_proto_serializers.py +++ b/tests/integration/schema_registry/_sync/test_proto_serializers.py @@ -19,7 +19,7 @@ from confluent_kafka import TopicPartition, KafkaException, KafkaError from confluent_kafka.error import ConsumeError from confluent_kafka.schema_registry.protobuf import ProtobufSerializer, ProtobufDeserializer 
-from .data.proto import metadata_proto_pb2, NestedTestProto_pb2, TestProto_pb2, \ +from tests.integration.schema_registry.data.proto import metadata_proto_pb2, NestedTestProto_pb2, TestProto_pb2, \ PublicTestProto_pb2 from tests.integration.schema_registry.data.proto.DependencyTestProto_pb2 import DependencyMessage from tests.integration.schema_registry.data.proto.exampleProtoCriteo_pb2 import ClickCas @@ -92,7 +92,7 @@ def test_protobuf_reference_registration(kafka_cluster, pb2, expected_refs): producer.produce(topic, key=pb2(), partition=0) producer.flush() - registered_refs = sr.get_schema(serializer._schema_id).references + registered_refs = (sr.get_schema(serializer._schema_id)).references assert expected_refs.sort() == [ref.name for ref in registered_refs].sort() From bd445fd19809ddd4b77cf404d702ce1a7753d57c Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Thu, 17 Apr 2025 14:56:08 -0700 Subject: [PATCH 02/32] remove avro module import --- src/confluent_kafka/schema_registry/avro.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/confluent_kafka/schema_registry/avro.py b/src/confluent_kafka/schema_registry/avro.py index 103acd58c..e570f18f8 100644 --- a/src/confluent_kafka/schema_registry/avro.py +++ b/src/confluent_kafka/schema_registry/avro.py @@ -16,5 +16,4 @@ # limitations under the License. from .common.avro import * -from ..avro import * from ._sync.avro import * From a5fbc78a68d165d25781031b78ce1a0080c62033 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Mon, 21 Apr 2025 11:02:34 -0700 Subject: [PATCH 03/32] Add top level __all__ to refactored modules --- .../schema_registry/__init__.py | 5 +-- .../schema_registry/_sync/avro.py | 6 +++- .../schema_registry/_sync/json_schema.py | 5 +++ .../schema_registry/_sync/protobuf.py | 5 +++ .../_sync/schema_registry_client.py | 21 ++++++------ .../schema_registry/_sync/serde.py | 6 ++++ .../schema_registry/common/avro.py | 15 ++++++++ .../schema_registry/common/json_schema.py | 13 +++++++ .../schema_registry/common/protobuf.py | 19 +++++++++++ .../common/schema_registry_client.py | 34 +++++++++++++++++++ .../schema_registry/common/serde.py | 18 ++++++++++ .../test_bearer_field_provider.py | 4 +-- 12 files changed, 135 insertions(+), 16 deletions(-) diff --git a/src/confluent_kafka/schema_registry/__init__.py b/src/confluent_kafka/schema_registry/__init__.py index d405f7433..97a869b14 100644 --- a/src/confluent_kafka/schema_registry/__init__.py +++ b/src/confluent_kafka/schema_registry/__init__.py @@ -15,7 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from typing import Optional from .schema_registry_client import ( ConfigCompatibilityLevel, @@ -44,6 +43,7 @@ ) __all__ = [ + "_MAGIC_BYTE", "ConfigCompatibilityLevel", "Metadata", "MetadataProperties", @@ -61,5 +61,6 @@ "ServerConfig", "topic_subject_name_strategy", "topic_record_subject_name_strategy", - "record_subject_name_strategy" + "record_subject_name_strategy", + "reference_subject_name_strategy", ] diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index a7cce69b2..6c032f7b1 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -34,6 +34,11 @@ from confluent_kafka.schema_registry.rule_registry import RuleRegistry from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, ParsedSchemaCache +__all__ = [ + '_resolve_named_schema', + 'AvroSerializer', + 'AvroDeserializer', +] def _resolve_named_schema( schema: Schema, schema_registry_client: SchemaRegistryClient @@ -344,7 +349,6 @@ def _get_parsed_schema(self, schema: Schema) -> AvroSchema: return parsed_schema - class AvroDeserializer(BaseDeserializer): """ Deserializer for Avro binary encoded data with Confluent Schema Registry diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index ac2513913..819618188 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -40,6 +40,11 @@ from confluent_kafka.serialization import (SerializationError, SerializationContext) +__all__ = [ + '_resolve_named_schema', + 'JSONSerializer', + 'JSONDeserializer' +] def _resolve_named_schema( schema: Schema, schema_registry_client: SchemaRegistryClient, diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index 7a263d124..e29401063 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -40,6 +40,11 @@ from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, ParsedSchemaCache +__all__ = [ + '_resolve_named_schema', + 'ProtobufSerializer', + 'ProtobufDeserializer', +] def _resolve_named_schema( schema: Schema, diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index 1f86dbf1c..471ff0857 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -40,8 +40,18 @@ full_jitter, _SchemaCache, Schema, + _StaticFieldProvider, ) +__all__ = [ + '_urlencode', + '_CustomOAuthClient', + '_OAuthClient', + '_BaseRestClient', + '_RestClient', + 'SchemaRegistryClient', +] + # TODO: consider adding `six` dependency or employing a compat file # Python 2.7 is officially EOL so compatibility issue will be come more the norm. # We need a better way to handle these issues. 
@@ -64,17 +74,6 @@ def _urlencode(value: str) -> str: log = logging.getLogger(__name__) -class _StaticFieldProvider(_BearerFieldProvider): - def __init__(self, token: str, logical_cluster: str, identity_pool: str): - self.token = token - self.logical_cluster = logical_cluster - self.identity_pool = identity_pool - - def get_bearer_fields(self) -> dict: - return {'bearer.auth.token': self.token, 'bearer.auth.logical.cluster': self.logical_cluster, - 'bearer.auth.identity.pool.id': self.identity_pool} - - class _CustomOAuthClient(_BearerFieldProvider): def __init__(self, custom_function: Callable[[Dict], Dict], custom_config: dict): self.custom_function = custom_function diff --git a/src/confluent_kafka/schema_registry/_sync/serde.py b/src/confluent_kafka/schema_registry/_sync/serde.py index 409e00160..d21fcaa3c 100644 --- a/src/confluent_kafka/schema_registry/_sync/serde.py +++ b/src/confluent_kafka/schema_registry/_sync/serde.py @@ -26,6 +26,12 @@ from confluent_kafka.serialization import Serializer, Deserializer, \ SerializationContext, SerializationError +__all__ = [ + 'BaseSerde', + 'BaseSerializer', + 'BaseDeserializer', +] + log = logging.getLogger(__name__) class BaseSerde(object): diff --git a/src/confluent_kafka/schema_registry/common/avro.py b/src/confluent_kafka/schema_registry/common/avro.py index a36e096c7..602cf933b 100644 --- a/src/confluent_kafka/schema_registry/common/avro.py +++ b/src/confluent_kafka/schema_registry/common/avro.py @@ -12,6 +12,21 @@ from confluent_kafka.schema_registry.serde import RuleContext, FieldType, \ FieldTransform, RuleConditionError +__all__ = [ + 'AvroMessage', + 'AvroSchema', + '_schema_loads', + 'LocalSchemaRepository', + 'parse_schema_with_repo', + 'transform', + '_transform_field', + 'get_type', + '_disjoint', + '_resolve_union', + 'get_inline_tags', + '_get_inline_tags_recursively', + '_implied_namespace', +] AvroMessage = Union[ None, # 'null' Avro type diff --git a/src/confluent_kafka/schema_registry/common/json_schema.py b/src/confluent_kafka/schema_registry/common/json_schema.py index 2a147c09b..902df18a9 100644 --- a/src/confluent_kafka/schema_registry/common/json_schema.py +++ b/src/confluent_kafka/schema_registry/common/json_schema.py @@ -14,6 +14,19 @@ from confluent_kafka.schema_registry.serde import RuleContext, FieldTransform, FieldType, \ RuleConditionError +__all__ = [ + 'JsonMessage', + 'JsonSchema', + 'DEFAULT_SPEC', + '_retrieve_via_httpx', + 'transform', + '_transform_field', + '_validate_subschemas', + 'get_type', + '_disjoint', + 'get_inline_tags', +] + JsonMessage = Union[ None, # 'null' Avro type str, # 'string' and 'enum' diff --git a/src/confluent_kafka/schema_registry/common/protobuf.py b/src/confluent_kafka/schema_registry/common/protobuf.py index bb8c9b6d4..f805ae9c6 100644 --- a/src/confluent_kafka/schema_registry/common/protobuf.py +++ b/src/confluent_kafka/schema_registry/common/protobuf.py @@ -25,6 +25,25 @@ from confluent_kafka.schema_registry.serde import RuleContext, FieldTransform, \ FieldType, RuleConditionError +__all__ = [ + '_bytes', + '_create_index_array', + '_schema_to_str', + '_proto_to_str', + '_str_to_proto', + '_init_pool', + 'transform', + '_transform_field', + '_set_field', + 'get_type', + 'is_map_field', + 'get_inline_tags', + '_disjoint', + '_is_builtin', + 'decimalToProtobuf', + 'protobufToDecimal' +] + # Convert an int to bytes (inverse of ord()) # Python3.chr() -> Unicode # Python2.chr() -> str(alias for bytes) diff --git 
a/src/confluent_kafka/schema_registry/common/schema_registry_client.py b/src/confluent_kafka/schema_registry/common/schema_registry_client.py index ca8efbe18..812929ca0 100644 --- a/src/confluent_kafka/schema_registry/common/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/common/schema_registry_client.py @@ -25,6 +25,29 @@ from typing import List, Dict, Type, TypeVar, \ cast, Optional, Any +__all__ = [ + 'VALID_AUTH_PROVIDERS', + '_BearerFieldProvider', + 'is_success', + 'is_retriable', + 'full_jitter', + '_StaticFieldProvider', + '_SchemaCache', + 'RuleKind', + 'RuleMode', + 'RuleParams', + 'Rule', + 'RuleSet', + 'MetadataTags', + 'MetadataProperties', + 'Metadata', + 'SchemaReference', + 'ConfigCompatibilityLevel', + 'ServerConfig', + 'Schema', + 'RegisteredSchema' +] + VALID_AUTH_PROVIDERS = ['URL', 'USER_INFO'] class _BearerFieldProvider(metaclass=abc.ABCMeta): @@ -46,6 +69,17 @@ def full_jitter(base_delay_ms: int, max_delay_ms: int, retries_attempted: int) - return random.random() * min(no_jitter_delay, max_delay_ms) +class _StaticFieldProvider(_BearerFieldProvider): + def __init__(self, token: str, logical_cluster: str, identity_pool: str): + self.token = token + self.logical_cluster = logical_cluster + self.identity_pool = identity_pool + + def get_bearer_fields(self) -> dict: + return {'bearer.auth.token': self.token, 'bearer.auth.logical.cluster': self.logical_cluster, + 'bearer.auth.identity.pool.id': self.identity_pool} + + class _SchemaCache(object): """ Thread-safe cache for use with the Schema Registry Client. diff --git a/src/confluent_kafka/schema_registry/common/serde.py b/src/confluent_kafka/schema_registry/common/serde.py index 864adca0c..82f9d23e6 100644 --- a/src/confluent_kafka/schema_registry/common/serde.py +++ b/src/confluent_kafka/schema_registry/common/serde.py @@ -29,6 +29,24 @@ from confluent_kafka.serialization import SerializationContext, SerializationError +__all__ = [ + 'FieldType', + 'FieldContext', + 'RuleContext', + 'FieldTransform', + 'FieldTransformer', + 'RuleBase', + 'RuleExecutor', + 'FieldRuleExecutor', + 'RuleAction', + 'ErrorAction', + 'NoneAction', + 'RuleError', + 'RuleConditionError', + 'Migration', + 'ParsedSchemaCache', +] + log = logging.getLogger(__name__) diff --git a/tests/schema_registry/test_bearer_field_provider.py b/tests/schema_registry/test_bearer_field_provider.py index d67804a12..a6dfc8eb0 100644 --- a/tests/schema_registry/test_bearer_field_provider.py +++ b/tests/schema_registry/test_bearer_field_provider.py @@ -77,8 +77,8 @@ def update_token2(): def test_generate_token_retry_logic(): oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 5, 1000, 20000) - with (patch("confluent_kafka.schema_registry.schema_registry_client.time.sleep") as mock_sleep, - patch("confluent_kafka.schema_registry.schema_registry_client.full_jitter") as mock_jitter): + with (patch("confluent_kafka.schema_registry._sync.schema_registry_client.time.sleep") as mock_sleep, + patch("confluent_kafka.schema_registry._sync.schema_registry_client.full_jitter") as mock_jitter): with pytest.raises(OAuthTokenError): oauth_client.generate_access_token() From 1e6404d5a7165ad856b919efd47ad3ac7ebe90c9 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Wed, 23 Apr 2025 12:06:32 -0700 Subject: [PATCH 04/32] reduce diffs --- .../schema_registry/__init__.py | 69 +++++++++++++++--- .../schema_registry/_sync/avro.py | 2 +- .../schema_registry/_sync/json_schema.py | 2 +- .../schema_registry/_sync/protobuf.py | 3 +- 
.../schema_registry/common/__init__.py | 70 ------------------- .../schema_registry/common/avro.py | 14 ++++ .../schema_registry/common/json_schema.py | 14 ++++ .../schema_registry/common/protobuf.py | 13 ++++ 8 files changed, 104 insertions(+), 83 deletions(-) diff --git a/src/confluent_kafka/schema_registry/__init__.py b/src/confluent_kafka/schema_registry/__init__.py index 97a869b14..e4ad4be17 100644 --- a/src/confluent_kafka/schema_registry/__init__.py +++ b/src/confluent_kafka/schema_registry/__init__.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from typing import Optional from .schema_registry_client import ( ConfigCompatibilityLevel, @@ -34,16 +35,9 @@ ServerConfig ) -from .common import ( - _MAGIC_BYTE, - topic_subject_name_strategy, - topic_record_subject_name_strategy, - record_subject_name_strategy, - reference_subject_name_strategy -) +_MAGIC_BYTE = 0 __all__ = [ - "_MAGIC_BYTE", "ConfigCompatibilityLevel", "Metadata", "MetadataProperties", @@ -61,6 +55,61 @@ "ServerConfig", "topic_subject_name_strategy", "topic_record_subject_name_strategy", - "record_subject_name_strategy", - "reference_subject_name_strategy", + "record_subject_name_strategy" ] + + +def topic_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: + """ + Constructs a subject name in the form of {topic}-key|value. + + Args: + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + record_name (Optional[str]): Record name. + + """ + return ctx.topic + "-" + ctx.field + + +def topic_record_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: + """ + Constructs a subject name in the form of {topic}-{record_name}. + + Args: + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + record_name (Optional[str]): Record name. + + """ + return ctx.topic + "-" + record_name if record_name is not None else None + + +def record_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: + """ + Constructs a subject name in the form of {record_name}. + + Args: + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + record_name (Optional[str]): Record name. + + """ + return record_name if record_name is not None else None + + +def reference_subject_name_strategy(ctx, schema_ref: SchemaReference) -> Optional[str]: + """ + Constructs a subject reference name in the form of {reference name}. + + Args: + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + schema_ref (SchemaReference): SchemaReference instance. 
+ + """ + return schema_ref.name if schema_ref is not None else None diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index 38ec36402..e3ca57b30 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -30,7 +30,7 @@ SchemaRegistryClient) from confluent_kafka.serialization import (SerializationError, SerializationContext) -from confluent_kafka.schema_registry.common import _ContextStringIO +from confluent_kafka.schema_registry.common.avro import _ContextStringIO from confluent_kafka.schema_registry.rule_registry import RuleRegistry from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, ParsedSchemaCache diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index 819618188..36107b8d0 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -33,7 +33,7 @@ from confluent_kafka.schema_registry.common.json_schema import ( DEFAULT_SPEC, JsonSchema, _retrieve_via_httpx, transform ) -from confluent_kafka.schema_registry.common import _ContextStringIO +from confluent_kafka.schema_registry.common.json_schema import _ContextStringIO from confluent_kafka.schema_registry.rule_registry import RuleRegistry from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, \ ParsedSchemaCache diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index a0ae75b9a..b3d940907 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -26,9 +26,10 @@ from google.protobuf.message import DecodeError, Message from google.protobuf.message_factory import GetMessageClass -from confluent_kafka.schema_registry.common import (_MAGIC_BYTE, _ContextStringIO, +from confluent_kafka.schema_registry import (_MAGIC_BYTE, reference_subject_name_strategy, topic_subject_name_strategy) +from confluent_kafka.schema_registry.common.protobuf import _ContextStringIO from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient from confluent_kafka.schema_registry.common.protobuf import _bytes, _create_index_array, _init_pool, _is_builtin, _schema_to_str, _str_to_proto, transform from confluent_kafka.schema_registry.rule_registry import RuleRegistry diff --git a/src/confluent_kafka/schema_registry/common/__init__.py b/src/confluent_kafka/schema_registry/common/__init__.py index c1e14957a..309c452f1 100644 --- a/src/confluent_kafka/schema_registry/common/__init__.py +++ b/src/confluent_kafka/schema_registry/common/__init__.py @@ -19,73 +19,3 @@ from typing import Optional from .schema_registry_client import SchemaReference - -_MAGIC_BYTE = 0 - -def topic_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: - """ - Constructs a subject name in the form of {topic}-key|value. - - Args: - ctx (SerializationContext): Metadata pertaining to the serialization - operation. - - record_name (Optional[str]): Record name. - - """ - return ctx.topic + "-" + ctx.field - - -def topic_record_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: - """ - Constructs a subject name in the form of {topic}-{record_name}. - - Args: - ctx (SerializationContext): Metadata pertaining to the serialization - operation. 
- - record_name (Optional[str]): Record name. - - """ - return ctx.topic + "-" + record_name if record_name is not None else None - - -def record_subject_name_strategy(ctx, record_name: Optional[str]) -> Optional[str]: - """ - Constructs a subject name in the form of {record_name}. - - Args: - ctx (SerializationContext): Metadata pertaining to the serialization - operation. - - record_name (Optional[str]): Record name. - - """ - return record_name if record_name is not None else None - - -def reference_subject_name_strategy(ctx, schema_ref: SchemaReference) -> Optional[str]: - """ - Constructs a subject reference name in the form of {reference name}. - - Args: - ctx (SerializationContext): Metadata pertaining to the serialization - operation. - - schema_ref (SchemaReference): SchemaReference instance. - - """ - return schema_ref.name if schema_ref is not None else None - - -class _ContextStringIO(BytesIO): - """ - Wrapper to allow use of StringIO via 'with' constructs. - """ - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() - return False diff --git a/src/confluent_kafka/schema_registry/common/avro.py b/src/confluent_kafka/schema_registry/common/avro.py index 602cf933b..7e038f6c7 100644 --- a/src/confluent_kafka/schema_registry/common/avro.py +++ b/src/confluent_kafka/schema_registry/common/avro.py @@ -2,6 +2,7 @@ import re from collections import defaultdict from copy import deepcopy +from io import BytesIO from json import loads from typing import Dict, Union, Optional, Set @@ -42,6 +43,19 @@ AvroSchema = Union[str, list, dict] +class _ContextStringIO(BytesIO): + """ + Wrapper to allow use of StringIO via 'with' constructs. + """ + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + return False + + def _schema_loads(schema_str: str) -> Schema: """ Instantiate a Schema instance from a declaration string. diff --git a/src/confluent_kafka/schema_registry/common/json_schema.py b/src/confluent_kafka/schema_registry/common/json_schema.py index 902df18a9..ef789fe69 100644 --- a/src/confluent_kafka/schema_registry/common/json_schema.py +++ b/src/confluent_kafka/schema_registry/common/json_schema.py @@ -42,6 +42,20 @@ DEFAULT_SPEC = referencing.jsonschema.DRAFT7 + +class _ContextStringIO(BytesIO): + """ + Wrapper to allow use of StringIO via 'with' constructs. + """ + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + return False + + def _retrieve_via_httpx(uri: str): response = httpx.get(uri) return Resource.from_contents( diff --git a/src/confluent_kafka/schema_registry/common/protobuf.py b/src/confluent_kafka/schema_registry/common/protobuf.py index f805ae9c6..8f4bbbaae 100644 --- a/src/confluent_kafka/schema_registry/common/protobuf.py +++ b/src/confluent_kafka/schema_registry/common/protobuf.py @@ -67,6 +67,19 @@ def _bytes(v: int) -> str: return chr(v) +class _ContextStringIO(io.BytesIO): + """ + Wrapper to allow use of StringIO via 'with' constructs. 
+ """ + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + return False + + def _create_index_array(msg_desc: Descriptor) -> List[int]: """ Creates an index array specifying the location of msg_desc in From 827f14bab61cc6af2c79da67beb80b7f722485fd Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Wed, 23 Apr 2025 12:10:05 -0700 Subject: [PATCH 05/32] refactor --- .../schema_registry/_sync/__init__.py | 17 --------------- .../schema_registry/_sync/avro.py | 4 ++-- .../schema_registry/_sync/json_schema.py | 3 +-- .../schema_registry/_sync/protobuf.py | 4 ++-- .../schema_registry/common/__init__.py | 21 ------------------- 5 files changed, 5 insertions(+), 44 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_sync/__init__.py b/src/confluent_kafka/schema_registry/_sync/__init__.py index 2b4389a06..e69de29bb 100644 --- a/src/confluent_kafka/schema_registry/_sync/__init__.py +++ b/src/confluent_kafka/schema_registry/_sync/__init__.py @@ -1,17 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright 2020 Confluent Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index e3ca57b30..00b873f98 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -21,7 +21,8 @@ from fastavro import schemaless_reader, schemaless_writer -from confluent_kafka.schema_registry.common.avro import AvroSchema, _schema_loads, get_inline_tags, parse_schema_with_repo, transform +from confluent_kafka.schema_registry.common.avro import AvroSchema, _schema_loads, \ + get_inline_tags, parse_schema_with_repo, transform, _ContextStringIO from confluent_kafka.schema_registry import (_MAGIC_BYTE, Schema, @@ -30,7 +31,6 @@ SchemaRegistryClient) from confluent_kafka.serialization import (SerializationError, SerializationContext) -from confluent_kafka.schema_registry.common.avro import _ContextStringIO from confluent_kafka.schema_registry.rule_registry import RuleRegistry from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, ParsedSchemaCache diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index 36107b8d0..42cfc22f8 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -31,9 +31,8 @@ RuleMode, SchemaRegistryClient) from confluent_kafka.schema_registry.common.json_schema import ( - DEFAULT_SPEC, JsonSchema, _retrieve_via_httpx, transform + DEFAULT_SPEC, JsonSchema, _retrieve_via_httpx, transform, _ContextStringIO ) -from confluent_kafka.schema_registry.common.json_schema import _ContextStringIO from confluent_kafka.schema_registry.rule_registry import RuleRegistry from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, \ ParsedSchemaCache diff --git 
a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index b3d940907..8949b7c30 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -29,9 +29,9 @@ from confluent_kafka.schema_registry import (_MAGIC_BYTE, reference_subject_name_strategy, topic_subject_name_strategy) -from confluent_kafka.schema_registry.common.protobuf import _ContextStringIO from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient -from confluent_kafka.schema_registry.common.protobuf import _bytes, _create_index_array, _init_pool, _is_builtin, _schema_to_str, _str_to_proto, transform +from confluent_kafka.schema_registry.common.protobuf import _bytes, _create_index_array, \ + _init_pool, _is_builtin, _schema_to_str, _str_to_proto, transform, _ContextStringIO from confluent_kafka.schema_registry.rule_registry import RuleRegistry from confluent_kafka.schema_registry import (Schema, SchemaReference, diff --git a/src/confluent_kafka/schema_registry/common/__init__.py b/src/confluent_kafka/schema_registry/common/__init__.py index 309c452f1..e69de29bb 100644 --- a/src/confluent_kafka/schema_registry/common/__init__.py +++ b/src/confluent_kafka/schema_registry/common/__init__.py @@ -1,21 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright 2020 Confluent Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -from io import BytesIO -from typing import Optional - -from .schema_registry_client import SchemaReference From 5cb44ddffd6edcf2e3bb9ee2a5688bcc97b1d45d Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Wed, 23 Apr 2025 12:32:27 -0700 Subject: [PATCH 06/32] formatting --- .../schema_registry/__init__.py | 66 +++++++++---------- .../schema_registry/_sync/avro.py | 13 ++-- .../schema_registry/_sync/json_schema.py | 7 +- .../schema_registry/_sync/protobuf.py | 15 ++--- .../_sync/schema_registry_client.py | 12 ++-- .../schema_registry/_sync/serde.py | 5 +- src/confluent_kafka/schema_registry/avro.py | 4 +- .../schema_registry/common/protobuf.py | 4 +- .../common/schema_registry_client.py | 2 + .../schema_registry/common/serde.py | 2 + .../schema_registry/json_schema.py | 4 +- .../schema_registry/protobuf.py | 4 +- .../schema_registry/schema_registry_client.py | 7 +- src/confluent_kafka/schema_registry/serde.py | 4 +- 14 files changed, 79 insertions(+), 70 deletions(-) diff --git a/src/confluent_kafka/schema_registry/__init__.py b/src/confluent_kafka/schema_registry/__init__.py index e4ad4be17..d6b9fd197 100644 --- a/src/confluent_kafka/schema_registry/__init__.py +++ b/src/confluent_kafka/schema_registry/__init__.py @@ -18,44 +18,44 @@ from typing import Optional from .schema_registry_client import ( - ConfigCompatibilityLevel, - Metadata, - MetadataProperties, - MetadataTags, - RegisteredSchema, - Rule, - RuleKind, - RuleMode, - RuleParams, - RuleSet, - Schema, - SchemaRegistryClient, - SchemaRegistryError, - SchemaReference, - ServerConfig + ConfigCompatibilityLevel, + Metadata, + MetadataProperties, + MetadataTags, + RegisteredSchema, + Rule, + RuleKind, + RuleMode, + RuleParams, + RuleSet, + Schema, + SchemaRegistryClient, + SchemaRegistryError, + SchemaReference, + ServerConfig ) _MAGIC_BYTE = 0 __all__ = [ - "ConfigCompatibilityLevel", - "Metadata", - "MetadataProperties", - "MetadataTags", - "RegisteredSchema", - "Rule", - "RuleKind", - "RuleMode", - "RuleParams", - "RuleSet", - "Schema", - "SchemaRegistryClient", - "SchemaRegistryError", - "SchemaReference", - "ServerConfig", - "topic_subject_name_strategy", - "topic_record_subject_name_strategy", - "record_subject_name_strategy" + "ConfigCompatibilityLevel", + "Metadata", + "MetadataProperties", + "MetadataTags", + "RegisteredSchema", + "Rule", + "RuleKind", + "RuleMode", + "RuleParams", + "RuleSet", + "Schema", + "SchemaRegistryClient", + "SchemaRegistryError", + "SchemaReference", + "ServerConfig", + "topic_subject_name_strategy", + "topic_record_subject_name_strategy", + "record_subject_name_strategy" ] diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index 00b873f98..57c792a18 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -25,10 +25,10 @@ get_inline_tags, parse_schema_with_repo, transform, _ContextStringIO from confluent_kafka.schema_registry import (_MAGIC_BYTE, - Schema, - topic_subject_name_strategy, - RuleMode, - SchemaRegistryClient) + Schema, + topic_subject_name_strategy, + RuleMode, + SchemaRegistryClient) from confluent_kafka.serialization import (SerializationError, SerializationContext) from confluent_kafka.schema_registry.rule_registry import RuleRegistry @@ -40,6 +40,7 @@ 'AvroDeserializer', ] + def _resolve_named_schema( schema: Schema, schema_registry_client: SchemaRegistryClient ) -> Dict[str, AvroSchema]: @@ -319,7 +320,7 @@ def __serialize(self, obj: object, ctx: 
Optional[SerializationContext] = None) - if latest_schema is not None: parsed_schema = self._get_parsed_schema(latest_schema.schema) - field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 + def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, parsed_schema, msg, field_transform)) value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, latest_schema.schema, value, get_inline_tags(parsed_schema), @@ -556,7 +557,7 @@ def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) reader_schema, self._return_record_name) - field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_schema, message, field_transform)) obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, reader_schema_raw, obj_dict, get_inline_tags(reader_schema), diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index 42cfc22f8..b0c8815c2 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -45,6 +45,7 @@ 'JSONDeserializer' ] + def _resolve_named_schema( schema: Schema, schema_registry_client: SchemaRegistryClient, ref_registry: Optional[Registry] = None @@ -70,7 +71,6 @@ def _resolve_named_schema( return ref_registry - class JSONSerializer(BaseSerializer): """ Serializer that outputs JSON encoded data with Confluent Schema Registry framing. @@ -338,7 +338,7 @@ def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) - root_resource = Resource.from_contents( parsed_schema, default_specification=DEFAULT_SPEC) ref_resolver = ref_registry.resolver_with_root(root_resource) - field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 + def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, latest_schema.schema, value, None, @@ -601,7 +601,8 @@ def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) - reader_root_resource = Resource.from_contents( reader_schema, default_specification=DEFAULT_SPEC) reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) - field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + + def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, "$", message, field_transform)) obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index 8949b7c30..83e4324aa 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -27,15 +27,15 @@ from google.protobuf.message_factory import GetMessageClass from confluent_kafka.schema_registry import (_MAGIC_BYTE, - reference_subject_name_strategy, - topic_subject_name_strategy) + reference_subject_name_strategy, + topic_subject_name_strategy) from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient from confluent_kafka.schema_registry.common.protobuf import _bytes, 
_create_index_array, \ _init_pool, _is_builtin, _schema_to_str, _str_to_proto, transform, _ContextStringIO from confluent_kafka.schema_registry.rule_registry import RuleRegistry from confluent_kafka.schema_registry import (Schema, - SchemaReference, - RuleMode) + SchemaReference, + RuleMode) from confluent_kafka.serialization import SerializationError, \ SerializationContext @@ -47,6 +47,7 @@ 'ProtobufDeserializer', ] + def _resolve_named_schema( schema: Schema, schema_registry_client: SchemaRegistryClient, @@ -73,7 +74,6 @@ def _resolve_named_schema( pool.Add(file_descriptor_proto) - class ProtobufSerializer(BaseSerializer): """ Serializer for Protobuf Message derived classes. Serialization format is Protobuf, @@ -428,7 +428,7 @@ def __serialize(self, message: Message, ctx: Optional[SerializationContext] = No fd_proto, pool = self._get_parsed_schema(latest_schema.schema) fd = pool.FindFileByName(fd_proto.name) desc = fd.message_types_by_name[message.DESCRIPTOR.name] - field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 + def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, desc, msg, field_transform)) message = self._execute_rules(ctx, subject, RuleMode.WRITE, None, latest_schema.schema, message, None, @@ -460,7 +460,6 @@ def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescrip return fd_proto, pool - class ProtobufDeserializer(BaseDeserializer): """ Deserializer for Protobuf serialized data with Confluent Schema Registry framing. @@ -755,7 +754,7 @@ def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) - except DecodeError as e: raise SerializationError(str(e)) - field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_desc, message, field_transform)) msg = self._execute_rules(ctx, subject, RuleMode.READ, None, reader_schema_raw, msg, None, diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index 471ff0857..8d259873e 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -32,13 +32,13 @@ from confluent_kafka.schema_registry.error import SchemaRegistryError, OAuthTokenError from confluent_kafka.schema_registry.common.schema_registry_client import ( - RegisteredSchema, - ServerConfig, - is_success, - is_retriable, + RegisteredSchema, + ServerConfig, + is_success, + is_retriable, _BearerFieldProvider, full_jitter, - _SchemaCache, + _SchemaCache, Schema, _StaticFieldProvider, ) @@ -98,7 +98,7 @@ def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoin def get_bearer_fields(self) -> dict: return { - 'bearer.auth.token': self.get_access_token(), + 'bearer.auth.token': self.get_access_token(), 'bearer.auth.logical.cluster': self.logical_cluster, 'bearer.auth.identity.pool.id': self.identity_pool } diff --git a/src/confluent_kafka/schema_registry/_sync/serde.py b/src/confluent_kafka/schema_registry/_sync/serde.py index d21fcaa3c..a0481f9b0 100644 --- a/src/confluent_kafka/schema_registry/_sync/serde.py +++ b/src/confluent_kafka/schema_registry/_sync/serde.py @@ -20,7 +20,9 @@ from typing import List, Optional, Set, Dict, Any from confluent_kafka.schema_registry import RegisteredSchema -from 
confluent_kafka.schema_registry.common.serde import ErrorAction, FieldTransformer, Migration, NoneAction, RuleAction, RuleConditionError, RuleContext, RuleError +from confluent_kafka.schema_registry.common.serde import ErrorAction, \ + FieldTransformer, Migration, NoneAction, RuleAction, \ + RuleConditionError, RuleContext, RuleError from confluent_kafka.schema_registry.schema_registry_client import RuleMode, \ Rule, RuleKind, Schema, RuleSet from confluent_kafka.serialization import Serializer, Deserializer, \ @@ -34,6 +36,7 @@ log = logging.getLogger(__name__) + class BaseSerde(object): __slots__ = ['_use_schema_id', '_use_latest_version', '_use_latest_with_metadata', '_registry', '_rule_registry', '_subject_name_func', diff --git a/src/confluent_kafka/schema_registry/avro.py b/src/confluent_kafka/schema_registry/avro.py index e570f18f8..0da66230e 100644 --- a/src/confluent_kafka/schema_registry/avro.py +++ b/src/confluent_kafka/schema_registry/avro.py @@ -15,5 +15,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .common.avro import * -from ._sync.avro import * +from .common.avro import * # noqa +from ._sync.avro import * # noqa diff --git a/src/confluent_kafka/schema_registry/common/protobuf.py b/src/confluent_kafka/schema_registry/common/protobuf.py index 8f4bbbaae..3b27940fc 100644 --- a/src/confluent_kafka/schema_registry/common/protobuf.py +++ b/src/confluent_kafka/schema_registry/common/protobuf.py @@ -324,8 +324,8 @@ def _disjoint(tags1: Set[str], tags2: Set[str]) -> bool: def _is_builtin(name: str) -> bool: return name.startswith('confluent/') or \ - name.startswith('google/protobuf/') or \ - name.startswith('google/type/') + name.startswith('google/protobuf/') or \ + name.startswith('google/type/') def decimalToProtobuf(value: Decimal, scale: int) -> decimal_pb2.Decimal: diff --git a/src/confluent_kafka/schema_registry/common/schema_registry_client.py b/src/confluent_kafka/schema_registry/common/schema_registry_client.py index 812929ca0..a9ab11756 100644 --- a/src/confluent_kafka/schema_registry/common/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/common/schema_registry_client.py @@ -50,6 +50,7 @@ VALID_AUTH_PROVIDERS = ['URL', 'USER_INFO'] + class _BearerFieldProvider(metaclass=abc.ABCMeta): @abc.abstractmethod def get_bearer_fields(self) -> dict: @@ -273,6 +274,7 @@ def clear(self): self.rs_version_index.clear() self.rs_schema_index.clear() + T = TypeVar("T") diff --git a/src/confluent_kafka/schema_registry/common/serde.py b/src/confluent_kafka/schema_registry/common/serde.py index 82f9d23e6..4244415c1 100644 --- a/src/confluent_kafka/schema_registry/common/serde.py +++ b/src/confluent_kafka/schema_registry/common/serde.py @@ -269,8 +269,10 @@ def __init__( self.source = source self.target = target + T = TypeVar("T") + class ParsedSchemaCache(object): """ Thread-safe cache for parsed schemas diff --git a/src/confluent_kafka/schema_registry/json_schema.py b/src/confluent_kafka/schema_registry/json_schema.py index cff914d96..2371c618e 100644 --- a/src/confluent_kafka/schema_registry/json_schema.py +++ b/src/confluent_kafka/schema_registry/json_schema.py @@ -15,5 +15,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .common.json_schema import * -from ._sync.json_schema import * +from .common.json_schema import * # noqa +from ._sync.json_schema import * # noqa diff --git a/src/confluent_kafka/schema_registry/protobuf.py b/src/confluent_kafka/schema_registry/protobuf.py index 135afca16..ba0f03a90 100644 --- a/src/confluent_kafka/schema_registry/protobuf.py +++ b/src/confluent_kafka/schema_registry/protobuf.py @@ -15,5 +15,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .common.protobuf import * -from ._sync.protobuf import * +from .common.protobuf import * # noqa +from ._sync.protobuf import * # noqa diff --git a/src/confluent_kafka/schema_registry/schema_registry_client.py b/src/confluent_kafka/schema_registry/schema_registry_client.py index 1763206ff..e9a0eb3e2 100644 --- a/src/confluent_kafka/schema_registry/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/schema_registry_client.py @@ -15,7 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .common.schema_registry_client import * -from ._sync.schema_registry_client import * -from .error import SchemaRegistryError +from .common.schema_registry_client import * # noqa +from ._sync.schema_registry_client import * # noqa + +from .error import SchemaRegistryError # noqa diff --git a/src/confluent_kafka/schema_registry/serde.py b/src/confluent_kafka/schema_registry/serde.py index 3f351349c..87037dc17 100644 --- a/src/confluent_kafka/schema_registry/serde.py +++ b/src/confluent_kafka/schema_registry/serde.py @@ -16,5 +16,5 @@ # limitations under the License. # -from .common.serde import * -from ._sync.serde import * +from .common.serde import * # noqa +from ._sync.serde import * # noqa From 9cf1b088b86d934c39e9e89a940bc1ef758e8726 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Tue, 27 May 2025 21:11:41 -0700 Subject: [PATCH 07/32] style fix --- .../schema_registry/__init__.py | 44 +++++++++---------- .../schema_registry/_sync/avro.py | 25 +++++------ .../schema_registry/_sync/json_schema.py | 3 +- .../schema_registry/_sync/protobuf.py | 2 +- src/confluent_kafka/schema_registry/avro.py | 4 +- src/confluent_kafka/schema_registry/error.py | 1 + .../schema_registry/json_schema.py | 4 +- .../schema_registry/protobuf.py | 4 +- .../schema_registry/schema_registry_client.py | 6 +-- src/confluent_kafka/schema_registry/serde.py | 4 +- 10 files changed, 47 insertions(+), 50 deletions(-) diff --git a/src/confluent_kafka/schema_registry/__init__.py b/src/confluent_kafka/schema_registry/__init__.py index 8e6310342..7d87eb608 100644 --- a/src/confluent_kafka/schema_registry/__init__.py +++ b/src/confluent_kafka/schema_registry/__init__.py @@ -45,28 +45,28 @@ _MAGIC_BYTE_V1 = 1 __all__ = [ - "ConfigCompatibilityLevel", - "Metadata", - "MetadataProperties", - "MetadataTags", - "RegisteredSchema", - "Rule", - "RuleKind", - "RuleMode", - "RuleParams", - "RuleSet", - "Schema", - "SchemaRegistryClient", - "SchemaRegistryError", - "SchemaReference", - "ServerConfig", - "topic_subject_name_strategy", - "topic_record_subject_name_strategy", - "record_subject_name_strategy", - "header_schema_id_serializer", - "prefix_schema_id_serializer", - "dual_schema_id_deserializer", - "prefix_schema_id_deserializer" + "ConfigCompatibilityLevel", + "Metadata", + "MetadataProperties", + "MetadataTags", + "RegisteredSchema", + "Rule", + "RuleKind", + "RuleMode", + "RuleParams", + "RuleSet", + "Schema", + "SchemaRegistryClient", + 
"SchemaRegistryError", + "SchemaReference", + "ServerConfig", + "topic_subject_name_strategy", + "topic_record_subject_name_strategy", + "record_subject_name_strategy", + "header_schema_id_serializer", + "prefix_schema_id_serializer", + "dual_schema_id_deserializer", + "prefix_schema_id_deserializer" ] diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index 584691a96..43edf04e1 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -28,8 +28,8 @@ Schema, topic_subject_name_strategy, RuleMode, - SchemaRegistryClient, - prefix_schema_id_serializer, + SchemaRegistryClient, + prefix_schema_id_serializer, dual_schema_id_deserializer) from confluent_kafka.serialization import (SerializationError, SerializationContext) @@ -244,10 +244,10 @@ def __init__( self._subject_name_func = conf_copy.pop('subject.name.strategy') if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - - self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') - if not callable(self._schema_id_deserializer): - raise ValueError("schema.id.deserializer must be callable") + + self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + if not callable(self._schema_id_serializer): + raise ValueError("schema.id.serializer must be callable") if len(conf_copy) > 0: raise ValueError("Unrecognized properties: {}" @@ -480,10 +480,9 @@ def __init__( if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_serializer = conf_copy.pop('schema.id.serializer') - if not callable(self._schema_id_serializer): - raise ValueError("schema.id.serializer must be callable") - + self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') + if not callable(self._schema_id_deserializer): + raise ValueError("schema.id.deserializer must be callable") if len(conf_copy) > 0: raise ValueError("Unrecognized properties: {}" @@ -579,11 +578,7 @@ def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) reader_schema, self._return_record_name) - - - - - field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_schema, message, field_transform)) obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, reader_schema_raw, obj_dict, get_inline_tags(reader_schema), diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index 9b3638199..bb642c9a3 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -612,7 +612,8 @@ def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) - reader_root_resource = Resource.from_contents( reader_schema, default_specification=DEFAULT_SPEC) reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) - field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + + def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, "$", message, field_transform)) obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py 
b/src/confluent_kafka/schema_registry/_sync/protobuf.py index 2ca24946c..1204c2230 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -664,7 +664,7 @@ def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) - except DecodeError as e: raise SerializationError(str(e)) - field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_desc, message, field_transform)) msg = self._execute_rules(ctx, subject, RuleMode.READ, None, reader_schema_raw, msg, None, diff --git a/src/confluent_kafka/schema_registry/avro.py b/src/confluent_kafka/schema_registry/avro.py index 0da66230e..0a9ba2db2 100644 --- a/src/confluent_kafka/schema_registry/avro.py +++ b/src/confluent_kafka/schema_registry/avro.py @@ -15,5 +15,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .common.avro import * # noqa -from ._sync.avro import * # noqa +from .common.avro import * # noqa +from ._sync.avro import * # noqa diff --git a/src/confluent_kafka/schema_registry/error.py b/src/confluent_kafka/schema_registry/error.py index e474cc055..2aa4d6dcd 100644 --- a/src/confluent_kafka/schema_registry/error.py +++ b/src/confluent_kafka/schema_registry/error.py @@ -57,6 +57,7 @@ def __str__(self): class OAuthTokenError(Exception): """Raised when an OAuth token cannot be retrieved.""" + def __init__(self, message, status_code=None, response_text=None): self.message = message self.status_code = status_code diff --git a/src/confluent_kafka/schema_registry/json_schema.py b/src/confluent_kafka/schema_registry/json_schema.py index 2371c618e..e60c8eafd 100644 --- a/src/confluent_kafka/schema_registry/json_schema.py +++ b/src/confluent_kafka/schema_registry/json_schema.py @@ -15,5 +15,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .common.json_schema import * # noqa -from ._sync.json_schema import * # noqa +from .common.json_schema import * # noqa +from ._sync.json_schema import * # noqa diff --git a/src/confluent_kafka/schema_registry/protobuf.py b/src/confluent_kafka/schema_registry/protobuf.py index ba0f03a90..c781e47be 100644 --- a/src/confluent_kafka/schema_registry/protobuf.py +++ b/src/confluent_kafka/schema_registry/protobuf.py @@ -15,5 +15,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .common.protobuf import * # noqa -from ._sync.protobuf import * # noqa +from .common.protobuf import * # noqa +from ._sync.protobuf import * # noqa diff --git a/src/confluent_kafka/schema_registry/schema_registry_client.py b/src/confluent_kafka/schema_registry/schema_registry_client.py index e9a0eb3e2..b7a008097 100644 --- a/src/confluent_kafka/schema_registry/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/schema_registry_client.py @@ -16,7 +16,7 @@ # limitations under the License. 
-from .common.schema_registry_client import * # noqa -from ._sync.schema_registry_client import * # noqa +from .common.schema_registry_client import * # noqa +from ._sync.schema_registry_client import * # noqa -from .error import SchemaRegistryError # noqa +from .error import SchemaRegistryError # noqa diff --git a/src/confluent_kafka/schema_registry/serde.py b/src/confluent_kafka/schema_registry/serde.py index 87037dc17..25fe4a48f 100644 --- a/src/confluent_kafka/schema_registry/serde.py +++ b/src/confluent_kafka/schema_registry/serde.py @@ -16,5 +16,5 @@ # limitations under the License. # -from .common.serde import * # noqa -from ._sync.serde import * # noqa +from .common.serde import * # noqa +from ._sync.serde import * # noqa From 09d8ebce2f4d57cc1f9458a712333d0692d43098 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Tue, 27 May 2025 21:23:31 -0700 Subject: [PATCH 08/32] fix flake8 --- setup.py | 2 +- src/confluent_kafka/schema_registry/_sync/avro.py | 2 +- src/confluent_kafka/schema_registry/_sync/json_schema.py | 4 +--- src/confluent_kafka/schema_registry/_sync/protobuf.py | 4 +--- src/confluent_kafka/schema_registry/common/protobuf.py | 4 ++-- .../schema_registry/common/schema_registry_client.py | 2 +- 6 files changed, 7 insertions(+), 11 deletions(-) diff --git a/setup.py b/setup.py index c401e6ad3..1141e7cee 100755 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ import os from setuptools import setup -from distutils.core import Extension +from setuptools import Extension import platform work_dir = os.path.dirname(os.path.realpath(__file__)) diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index 43edf04e1..f43071fbe 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -16,7 +16,7 @@ # limitations under the License. from json import loads -from struct import pack, unpack +from struct import pack from typing import Dict, Union, Optional, Callable from fastavro import schemaless_reader, schemaless_writer diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index bb642c9a3..61c46cd96 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -16,7 +16,6 @@ # limitations under the License. import json -import struct from typing import Union, Optional, Tuple, Callable from cachetools import LRUCache @@ -25,8 +24,7 @@ from jsonschema.validators import validator_for from referencing import Registry, Resource -from confluent_kafka.schema_registry import (_MAGIC_BYTE, - Schema, +from confluent_kafka.schema_registry import (Schema, topic_subject_name_strategy, RuleMode, SchemaRegistryClient, prefix_schema_id_serializer, diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index 1204c2230..7d4070465 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -16,7 +16,6 @@ # limitations under the License. 
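The flake8 pass above removes the now-unused struct and _MAGIC_BYTE imports. For reference, a minimal sketch of the classic wire framing those imports served, a zero magic byte followed by a 4-byte big-endian schema ID, matching the pack('>bI', ...) and unpack('>bI', ...) calls that appear later in this series; the frame/unframe helper names are illustrative only, not library API:

from struct import pack, unpack

MAGIC_BYTE = 0  # same value the registry package uses for _MAGIC_BYTE

def frame(schema_id: int, encoded: bytes) -> bytes:
    # 1-byte magic + 4-byte big-endian schema ID, then the encoded record
    return pack('>bI', MAGIC_BYTE, schema_id) + encoded

def unframe(data: bytes):
    # Inverse: split the 5-byte header off a framed payload
    magic, schema_id = unpack('>bI', data[:5])
    if magic != MAGIC_BYTE:
        raise ValueError("unexpected magic byte {}".format(magic))
    return schema_id, data[5:]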
import io -import struct import warnings from typing import Set, List, Union, Optional, Tuple @@ -26,8 +25,7 @@ from google.protobuf.message import DecodeError, Message from google.protobuf.message_factory import GetMessageClass -from confluent_kafka.schema_registry import (_MAGIC_BYTE, - reference_subject_name_strategy, +from confluent_kafka.schema_registry import (reference_subject_name_strategy, topic_subject_name_strategy, prefix_schema_id_serializer, dual_schema_id_deserializer) from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient diff --git a/src/confluent_kafka/schema_registry/common/protobuf.py b/src/confluent_kafka/schema_registry/common/protobuf.py index 832701ef3..3a55bc5de 100644 --- a/src/confluent_kafka/schema_registry/common/protobuf.py +++ b/src/confluent_kafka/schema_registry/common/protobuf.py @@ -40,8 +40,8 @@ 'get_inline_tags', '_disjoint', '_is_builtin', - 'decimalToProtobuf', - 'protobufToDecimal' + 'decimal_to_protobuf', + 'protobuf_to_decimal' ] # Convert an int to bytes (inverse of ord()) diff --git a/src/confluent_kafka/schema_registry/common/schema_registry_client.py b/src/confluent_kafka/schema_registry/common/schema_registry_client.py index 3630afc9d..edbd38d02 100644 --- a/src/confluent_kafka/schema_registry/common/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/common/schema_registry_client.py @@ -23,7 +23,7 @@ from enum import Enum from threading import Lock from typing import List, Dict, Type, TypeVar, \ - cast, Optional, Any + cast, Optional, Any, Tuple __all__ = [ 'VALID_AUTH_PROVIDERS', From 4ae7877593a20bfc18f39a83618cf4b9daf726da Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Tue, 27 May 2025 21:45:30 -0700 Subject: [PATCH 09/32] fix --- src/confluent_kafka/schema_registry/_sync/avro.py | 2 -- .../schema_registry/_sync/serde.py | 15 ++++++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index f43071fbe..a6ee94e9b 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -344,8 +344,6 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 parsed_schema = self._parsed_schema with _ContextStringIO() as fo: - # Write the magic byte and schema ID in network byte order (big endian) - fo.write(pack('>bI', _MAGIC_BYTE, self._schema_id)) # write the record to the rest of the buffer schemaless_writer(fo, parsed_schema, value) diff --git a/src/confluent_kafka/schema_registry/_sync/serde.py b/src/confluent_kafka/schema_registry/_sync/serde.py index a0481f9b0..84ce15726 100644 --- a/src/confluent_kafka/schema_registry/_sync/serde.py +++ b/src/confluent_kafka/schema_registry/_sync/serde.py @@ -22,7 +22,7 @@ from confluent_kafka.schema_registry import RegisteredSchema from confluent_kafka.schema_registry.common.serde import ErrorAction, \ FieldTransformer, Migration, NoneAction, RuleAction, \ - RuleConditionError, RuleContext, RuleError + RuleConditionError, RuleContext, RuleError, SchemaId from confluent_kafka.schema_registry.schema_registry_client import RuleMode, \ Rule, RuleKind, Schema, RuleSet from confluent_kafka.serialization import Serializer, Deserializer, \ @@ -181,11 +181,20 @@ def _get_rule_action(self, ctx: RuleContext, action_name: str) -> Optional[RuleA class BaseSerializer(BaseSerde, Serializer): - __slots__ = ['_auto_register', '_normalize_schemas'] + __slots__ = ['_auto_register', 
'_normalize_schemas', '_schema_id_serializer'] class BaseDeserializer(BaseSerde, Deserializer): - __slots__ = [] + __slots__ = ['_schema_id_deserializer'] + + def _get_writer_schema(self, schema_id: SchemaId, subject: Optional[str] = None, + fmt: Optional[str] = None) -> Schema: + if schema_id.id is not None: + return self._registry.get_schema(schema_id.id, subject, fmt) + elif schema_id.guid is not None: + return self._registry.get_schema_by_guid(str(schema_id.guid), fmt) + else: + raise SerializationError("Schema ID or GUID is not set") def _has_rules(self, rule_set: RuleSet, mode: RuleMode) -> bool: if rule_set is None: From c4cfbcb884b64f41651016fc63ddac740ba80c7e Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Tue, 27 May 2025 22:00:51 -0700 Subject: [PATCH 10/32] fix flake8 --- src/confluent_kafka/schema_registry/_sync/avro.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index a6ee94e9b..651dee7ca 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -16,7 +16,6 @@ # limitations under the License. from json import loads -from struct import pack from typing import Dict, Union, Optional, Callable from fastavro import schemaless_reader, schemaless_writer @@ -24,8 +23,7 @@ from confluent_kafka.schema_registry.common.avro import AvroSchema, _schema_loads, \ get_inline_tags, parse_schema_with_repo, transform, _ContextStringIO, AVRO_TYPE -from confluent_kafka.schema_registry import (_MAGIC_BYTE, - Schema, +from confluent_kafka.schema_registry import (Schema, topic_subject_name_strategy, RuleMode, SchemaRegistryClient, From 1a6fa96820c89016d37c7a164168958daca43467 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Thu, 29 May 2025 16:26:55 -0700 Subject: [PATCH 11/32] revert --- Makefile | 2 +- tools/style-format.sh | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 1798c9d4f..3615e2b93 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ docs: style-check: @(tools/style-format.sh \ - $$(git ls-tree -r --name-only HEAD | egrep '\.(py)$$') ) + $$(git ls-tree -r --name-only HEAD | egrep '\.(c|h|py)$$') ) style-check-changed: @(tools/style-format.sh \ diff --git a/tools/style-format.sh b/tools/style-format.sh index e1de4272f..a686cc6f0 100755 --- a/tools/style-format.sh +++ b/tools/style-format.sh @@ -26,11 +26,11 @@ else fix=0 fi -# clang_format_version=$(${CLANG_FORMAT} --version | sed -Ee 's/.*version ([[:digit:]]+)\.[[:digit:]]+\.[[:digit:]]+.*/\1/') -# if [[ $clang_format_version != "10" ]] ; then -# echo "$0: clang-format version 10, '$clang_format_version' detected" -# exit 1 -# fi +clang_format_version=$(${CLANG_FORMAT} --version | sed -Ee 's/.*version ([[:digit:]]+)\.[[:digit:]]+\.[[:digit:]]+.*/\1/') +if [[ $clang_format_version != "10" ]] ; then + echo "$0: clang-format version 10, '$clang_format_version' detected" + exit 1 +fi # Get list of files from .formatignore to ignore formatting for. ignore_files=( $(grep '^[^#]..' 
.formatignore) ) From bf64e651ae3f661efd087b44b9c8926a9481edcd Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Thu, 29 May 2025 16:30:58 -0700 Subject: [PATCH 12/32] reduce diff --- .../schema_registry/__init__.py | 74 +++++++++---------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/src/confluent_kafka/schema_registry/__init__.py b/src/confluent_kafka/schema_registry/__init__.py index 7d87eb608..2d81f44ac 100644 --- a/src/confluent_kafka/schema_registry/__init__.py +++ b/src/confluent_kafka/schema_registry/__init__.py @@ -19,21 +19,21 @@ from typing import Optional from .schema_registry_client import ( - ConfigCompatibilityLevel, - Metadata, - MetadataProperties, - MetadataTags, - RegisteredSchema, - Rule, - RuleKind, - RuleMode, - RuleParams, - RuleSet, - Schema, - SchemaRegistryClient, - SchemaRegistryError, - SchemaReference, - ServerConfig + ConfigCompatibilityLevel, + Metadata, + MetadataProperties, + MetadataTags, + RegisteredSchema, + Rule, + RuleKind, + RuleMode, + RuleParams, + RuleSet, + Schema, + SchemaRegistryClient, + SchemaRegistryError, + SchemaReference, + ServerConfig ) from ..serialization import SerializationError, MessageField @@ -45,28 +45,28 @@ _MAGIC_BYTE_V1 = 1 __all__ = [ - "ConfigCompatibilityLevel", - "Metadata", - "MetadataProperties", - "MetadataTags", - "RegisteredSchema", - "Rule", - "RuleKind", - "RuleMode", - "RuleParams", - "RuleSet", - "Schema", - "SchemaRegistryClient", - "SchemaRegistryError", - "SchemaReference", - "ServerConfig", - "topic_subject_name_strategy", - "topic_record_subject_name_strategy", - "record_subject_name_strategy", - "header_schema_id_serializer", - "prefix_schema_id_serializer", - "dual_schema_id_deserializer", - "prefix_schema_id_deserializer" + "ConfigCompatibilityLevel", + "Metadata", + "MetadataProperties", + "MetadataTags", + "RegisteredSchema", + "Rule", + "RuleKind", + "RuleMode", + "RuleParams", + "RuleSet", + "Schema", + "SchemaRegistryClient", + "SchemaRegistryError", + "SchemaReference", + "ServerConfig", + "topic_subject_name_strategy", + "topic_record_subject_name_strategy", + "record_subject_name_strategy", + "header_schema_id_serializer", + "prefix_schema_id_serializer", + "dual_schema_id_deserializer", + "prefix_schema_id_deserializer" ] From 95da806c087b21b95b9d7dcf272574f6b76b9726 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Thu, 17 Apr 2025 14:36:01 -0700 Subject: [PATCH 13/32] Implement async variants of SR clients and serdes --- LICENSE | 33 + pyproject.toml | 3 + requirements/requirements-tests.txt | 1 + setup.py | 132 +- .../schema_registry/_async/__init__.py | 17 + .../schema_registry/_async/avro.py | 577 +++++++++ .../schema_registry/_async/json_schema.py | 643 ++++++++++ .../schema_registry/_async/protobuf.py | 801 ++++++++++++ .../_async/schema_registry_client.py | 1115 +++++++++++++++++ .../schema_registry/_async/serde.py | 252 ++++ tests/integration/cluster_fixture.py | 71 ++ .../schema_registry/_async/__init__.py | 0 .../schema_registry/_async/test_api_client.py | 491 ++++++++ .../_async/test_avro_serializers.py | 350 ++++++ .../_async/test_json_serializers.py | 490 ++++++++ .../_async/test_proto_serializers.py | 149 +++ 16 files changed, 5123 insertions(+), 2 deletions(-) create mode 100644 src/confluent_kafka/schema_registry/_async/__init__.py create mode 100644 src/confluent_kafka/schema_registry/_async/avro.py create mode 100644 src/confluent_kafka/schema_registry/_async/json_schema.py create mode 100644 src/confluent_kafka/schema_registry/_async/protobuf.py create mode 
100644 src/confluent_kafka/schema_registry/_async/schema_registry_client.py create mode 100644 src/confluent_kafka/schema_registry/_async/serde.py create mode 100644 tests/integration/schema_registry/_async/__init__.py create mode 100644 tests/integration/schema_registry/_async/test_api_client.py create mode 100644 tests/integration/schema_registry/_async/test_avro_serializers.py create mode 100644 tests/integration/schema_registry/_async/test_json_serializers.py create mode 100644 tests/integration/schema_registry/_async/test_proto_serializers.py diff --git a/LICENSE b/LICENSE index 521517282..02f2aa004 100644 --- a/LICENSE +++ b/LICENSE @@ -652,3 +652,36 @@ For the files wingetopt.c wingetopt.h downloaded from https://github.com/alex85k */ + +LICENSE.unasync +-------------------------------------------------------------- +For unasync code in setup.py, derived from +https://github.com/encode/httpcore/blob/ae46dfbd4330eefaa9cd6ab1560dec18a1d0bcb8/scripts/unasync.py + +Copyright © 2020, [Encode OSS Ltd](https://www.encode.io/). +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
\ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f9df81c65..129373f63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,3 +73,6 @@ optional-dependencies.all = { file = [ "requirements/requirements-avro.txt", "requirements/requirements-json.txt", "requirements/requirements-protobuf.txt"] } + +[tool.pytest.ini_options] +asyncio_mode = "auto" diff --git a/requirements/requirements-tests.txt b/requirements/requirements-tests.txt index 84e6818ca..932f9a427 100644 --- a/requirements/requirements-tests.txt +++ b/requirements/requirements-tests.txt @@ -9,3 +9,4 @@ requests-mock respx pytest_cov pluggy<1.6.0 +pytest-asyncio diff --git a/setup.py b/setup.py index 1141e7cee..91b0fb9f2 100755 --- a/setup.py +++ b/setup.py @@ -1,9 +1,13 @@ #!/usr/bin/env python import os +import platform +import re +import sys from setuptools import setup from setuptools import Extension -import platform +from setuptools.command.build_py import build_py as _build_py +from pprint import pprint work_dir = os.path.dirname(os.path.realpath(__file__)) mod_dir = os.path.join(work_dir, 'src', 'confluent_kafka') @@ -25,4 +29,128 @@ os.path.join(ext_dir, 'AdminTypes.c'), os.path.join(ext_dir, 'Admin.c')]) -setup(ext_modules=[module]) +SUBS = [ + ('from confluent_kafka.schema_registry.common import asyncinit', ''), + ('@asyncinit', ''), + ('import asyncio', ''), + ('asyncio.sleep', 'time.sleep'), + + ('Async([A-Z][A-Za-z0-9_]*)', r'\2'), + ('_Async([A-Z][A-Za-z0-9_]*)', r'_\2'), + ('async_([a-z][A-Za-z0-9_]*)', r'\2'), + + ('async def', 'def'), + ('await ', ''), + ('aclose', 'close'), + ('__aenter__', '__enter__'), + ('__aexit__', '__exit__'), + ('__aiter__', '__iter__'), +] +COMPILED_SUBS = [ + (re.compile(r'(^|\b)' + regex + r'($|\b)'), repl) + for regex, repl in SUBS +] + +USED_SUBS = set() + +def unasync_line(line): + for index, (regex, repl) in enumerate(COMPILED_SUBS): + old_line = line + line = re.sub(regex, repl, line) + if old_line != line: + USED_SUBS.add(index) + return line + + +def unasync_file(in_path, out_path): + with open(in_path, "r") as in_file: + with open(out_path, "w", newline="") as out_file: + for line in in_file.readlines(): + line = unasync_line(line) + out_file.write(line) + + +def unasync_file_check(in_path, out_path): + with open(in_path, "r") as in_file: + with open(out_path, "r") as out_file: + for in_line, out_line in zip(in_file.readlines(), out_file.readlines()): + expected = unasync_line(in_line) + if out_line != expected: + print(f'unasync mismatch between {in_path!r} and {out_path!r}') + print(f'Async code: {in_line!r}') + print(f'Expected sync code: {expected!r}') + print(f'Actual sync code: {out_line!r}') + sys.exit(1) + + +def unasync_dir(in_dir, out_dir, check_only=False): + for dirpath, dirnames, filenames in os.walk(in_dir): + for filename in filenames: + if not filename.endswith('.py'): + continue + rel_dir = os.path.relpath(dirpath, in_dir) + in_path = os.path.normpath(os.path.join(in_dir, rel_dir, filename)) + out_path = os.path.normpath(os.path.join(out_dir, rel_dir, filename)) + print(in_path, '->', out_path) + if check_only: + unasync_file_check(in_path, out_path) + else: + unasync_file(in_path, out_path) + +def unasync(): + unasync_dir( + "src/confluent_kafka/schema_registry/_async", + "src/confluent_kafka/schema_registry/_sync", + check_only=False + ) + unasync_dir( + "tests/integration/schema_registry/_async", + "tests/integration/schema_registry/_sync", + check_only=False + ) + + + if len(USED_SUBS) != len(SUBS): + unused_subs = [SUBS[i] for 
i in range(len(SUBS)) if i not in USED_SUBS] + + print("These patterns were not used:") + pprint(unused_subs) + + +class build_py(_build_py): + """ + Subclass build_py from setuptools to modify its behavior. + + Convert files in _async dir from being asynchronous to synchronous + and saves them to the specified output directory. + """ + + def run(self): + self._updated_files = [] + + # Base class code + if self.py_modules: + self.build_modules() + if self.packages: + self.build_packages() + self.build_package_data() + + # Our modification + unasync() + + # Remaining base class code + self.byte_compile(self.get_outputs(include_bytecode=0)) + + def build_module(self, module, module_file, package): + outfile, copied = super().build_module(module, module_file, package) + if copied: + self._updated_files.append(outfile) + return outfile, copied + + +setup( + ext_modules=[module], + cmdclass={ + 'build_py': build_py, + } +) diff --git a/src/confluent_kafka/schema_registry/_async/__init__.py b/src/confluent_kafka/schema_registry/_async/__init__.py new file mode 100644 index 000000000..2b4389a06 --- /dev/null +++ b/src/confluent_kafka/schema_registry/_async/__init__.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/src/confluent_kafka/schema_registry/_async/avro.py b/src/confluent_kafka/schema_registry/_async/avro.py new file mode 100644 index 000000000..713182408 --- /dev/null +++ b/src/confluent_kafka/schema_registry/_async/avro.py @@ -0,0 +1,577 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
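To make the unasync rewrite above concrete, here is a small self-contained sketch of how the SUBS patterns are applied line by line; only three of the rules are shown and the sample input line is made up for illustration. The full list in setup.py additionally strips the @asyncinit decorator and rewrites asyncio.sleep, among others.

import re

# Three substitution rules, wrapped with boundary groups the same way setup.py does
subs = [
    (re.compile(r'(^|\b)' + pattern + r'($|\b)'), repl)
    for pattern, repl in [
        ('Async([A-Z][A-Za-z0-9_]*)', r'\2'),  # AsyncFoo -> Foo
        ('async def', 'def'),
        ('await ', ''),
    ]
]

line = "    async def fetch(self): return await AsyncSchemaRegistryClient.get_schema(self, 1)"
for regex, repl in subs:
    line = re.sub(regex, repl, line)
print(line)
# -> "    def fetch(self): return SchemaRegistryClient.get_schema(self, 1)"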
+ +from json import loads +from struct import pack, unpack +from typing import Dict, Union, Optional, Callable + +from fastavro import schemaless_reader, schemaless_writer + +from confluent_kafka.schema_registry.common.avro import AvroSchema, _schema_loads, get_inline_tags, parse_schema_with_repo, transform + +from confluent_kafka.schema_registry import (_MAGIC_BYTE, + Schema, + topic_subject_name_strategy, + RuleMode, + AsyncSchemaRegistryClient) +from confluent_kafka.serialization import (SerializationError, + SerializationContext) +from confluent_kafka.schema_registry.common import _ContextStringIO, asyncinit +from confluent_kafka.schema_registry.rule_registry import RuleRegistry +from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, ParsedSchemaCache + + +async def _resolve_named_schema( + schema: Schema, schema_registry_client: AsyncSchemaRegistryClient +) -> Dict[str, AvroSchema]: + """ + Resolves named schemas referenced by the provided schema recursively. + :param schema: Schema to resolve named schemas for. + :param schema_registry_client: SchemaRegistryClient to use for retrieval. + :return: named_schemas dict. + """ + named_schemas = {} + if schema.references is not None: + for ref in schema.references: + referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True) + ref_named_schemas = await _resolve_named_schema(referenced_schema.schema, schema_registry_client) + parsed_schema = parse_schema_with_repo( + referenced_schema.schema.schema_str, named_schemas=ref_named_schemas) + named_schemas.update(ref_named_schemas) + named_schemas[ref.name] = parsed_schema + return named_schemas + +@asyncinit +class AsyncAvroSerializer(AsyncBaseSerializer): + """ + Serializer that outputs Avro binary encoded data with Confluent Schema Registry framing. + + Configuration properties: + + +-----------------------------+----------+--------------------------------------------------+ + | Property Name | Type | Description | + +=============================+==========+==================================================+ + | | | If True, automatically register the configured | + | ``auto.register.schemas`` | bool | schema with Confluent Schema Registry if it has | + | | | not previously been associated with the relevant | + | | | subject (determined via subject.name.strategy). | + | | | | + | | | Defaults to True. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to normalize schemas, which will | + | ``normalize.schemas`` | bool | transform schemas to have a consistent format, | + | | | including ordering properties and references. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the given schema ID for | + | ``use.schema.id`` | int | serialization. | + | | | | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | serialization. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to False. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata``| dict | the given metadata. 
| + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to None. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-----------------------------+----------+--------------------------------------------------+ + + Schemas are registered against subject names in Confluent Schema Registry that + define a scope in which the schemas can be evolved. By default, the subject name + is formed by concatenating the topic name with the message field (key or value) + separated by a hyphen. + + i.e. {topic name}-{message field} + + Alternative naming strategies may be configured with the property + ``subject.name.strategy``. + + Supported subject name strategies: + + +--------------------------------------+------------------------------+ + | Subject Name Strategy | Output Format | + +======================================+==============================+ + | topic_subject_name_strategy(default) | {topic name}-{message field} | + +--------------------------------------+------------------------------+ + | topic_record_subject_name_strategy | {topic name}-{record name} | + +--------------------------------------+------------------------------+ + | record_subject_name_strategy | {record name} | + +--------------------------------------+------------------------------+ + + See `Subject name strategy `_ for additional details. + + Note: + Prior to serialization, all values must first be converted to + a dict instance. This may handled manually prior to calling + :py:func:`Producer.produce()` or by registering a `to_dict` + callable with AvroSerializer. + + See ``avro_producer.py`` in the examples directory for example usage. + + Note: + Tuple notation can be used to determine which branch of an ambiguous union to take. + + See `fastavro notation `_ + + Args: + schema_registry_client (SchemaRegistryClient): Schema Registry client instance. + + schema_str (str or Schema): + Avro `Schema Declaration. `_ + Accepts either a string or a :py:class:`Schema` instance. Note that string + definitions cannot reference other schemas. For referencing other schemas, + use a :py:class:`Schema` instance. + + to_dict (callable, optional): Callable(object, SerializationContext) -> dict. Converts object to a dict. + + conf (dict): AvroSerializer configuration. 
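A hedged usage sketch for the serializer documented above: the registry URL, topic, User class and to_dict lambda are placeholders, and awaiting the constructor assumes the @asyncinit decorator applied to this class makes the async __init__ awaitable.

import asyncio

from confluent_kafka.schema_registry import (AsyncSchemaRegistryClient,
                                              topic_record_subject_name_strategy)
from confluent_kafka.schema_registry._async.avro import AsyncAvroSerializer
from confluent_kafka.serialization import MessageField, SerializationContext

USER_SCHEMA = """
{"type": "record", "name": "User",
 "fields": [{"name": "name", "type": "string"},
            {"name": "age", "type": "int"}]}
"""

class User:
    def __init__(self, name, age):
        self.name, self.age = name, age

async def main():
    client = AsyncSchemaRegistryClient({'url': 'http://localhost:8081'})  # placeholder URL
    serializer = await AsyncAvroSerializer(          # awaitable construction via @asyncinit
        client, schema_str=USER_SCHEMA,
        to_dict=lambda user, ctx: {'name': user.name, 'age': user.age},
        conf={'auto.register.schemas': True,
              'subject.name.strategy': topic_record_subject_name_strategy})
    ctx = SerializationContext('users', MessageField.VALUE)
    payload = await serializer(User('jane', 33), ctx)  # __call__ delegates to the async serialize
    print(len(payload))

asyncio.run(main())

With topic_record_subject_name_strategy the schema is registered under the subject "users-User" rather than the default "users-value".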
+ """ # noqa: E501 + __slots__ = ['_known_subjects', '_parsed_schema', '_schema', + '_schema_id', '_schema_name', '_to_dict', '_parsed_schemas'] + + _default_conf = {'auto.register.schemas': True, + 'normalize.schemas': False, + 'use.schema.id': None, + 'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy} + + async def __init__( + self, + schema_registry_client: AsyncSchemaRegistryClient, + schema_str: Union[str, Schema, None] = None, + to_dict: Optional[Callable[[object, SerializationContext], dict]] = None, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None + ): + super().__init__() + if isinstance(schema_str, str): + schema = _schema_loads(schema_str) + elif isinstance(schema_str, Schema): + schema = schema_str + else: + schema = None + + self._registry = schema_registry_client + self._schema_id = None + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._known_subjects = set() + self._parsed_schemas = ParsedSchemaCache() + + if to_dict is not None and not callable(to_dict): + raise ValueError("to_dict must be callable with the signature " + "to_dict(object, SerializationContext)->dict") + + self._to_dict = to_dict + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._auto_register = conf_copy.pop('auto.register.schemas') + if not isinstance(self._auto_register, bool): + raise ValueError("auto.register.schemas must be a boolean value") + + self._normalize_schemas = conf_copy.pop('normalize.schemas') + if not isinstance(self._normalize_schemas, bool): + raise ValueError("normalize.schemas must be a boolean value") + + self._use_schema_id = conf_copy.pop('use.schema.id') + if (self._use_schema_id is not None and + not isinstance(self._use_schema_id, int)): + raise ValueError("use.schema.id must be an int value") + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + if self._use_latest_version and self._auto_register: + raise ValueError("cannot enable both use.latest.version and auto.register.schemas") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + if schema: + parsed_schema = await self._get_parsed_schema(schema) + + if isinstance(parsed_schema, list): + # if parsed_schema is a list, we have an Avro union and there + # is no valid schema name. This is fine because the only use of + # schema_name is for supplying the subject name to the registry + # and union types should use topic_subject_name_strategy, which + # just discards the schema name anyway + schema_name = None + else: + # The Avro spec states primitives have a name equal to their type + # i.e. {"type": "string"} has a name of string. + # This function does not comply. 
+ # https://github.com/fastavro/fastavro/issues/415 + schema_dict = loads(schema.schema_str) + schema_name = parsed_schema.get("name", schema_dict.get("type")) + else: + schema_name = None + parsed_schema = None + + self._schema = schema + self._schema_name = schema_name + self._parsed_schema = parsed_schema + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + return self.__serialize(obj, ctx) + + async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + """ + Serializes an object to Avro binary format, prepending it with Confluent + Schema Registry framing. + + Args: + obj (object): The object instance to serialize. + + ctx (SerializationContext): Metadata pertaining to the serialization operation. + + Raises: + SerializerError: If any error occurs serializing obj. + SchemaRegistryError: If there was an error registering the schema with + Schema Registry, or auto.register.schemas is + false and the schema was not registered. + + Returns: + bytes: Confluent Schema Registry encoded Avro bytes + """ + + if obj is None: + return None + + subject = self._subject_name_func(ctx, self._schema_name) + latest_schema = await self._get_reader_schema(subject) + if latest_schema is not None: + self._schema_id = latest_schema.schema_id + elif subject not in self._known_subjects: + # Check to ensure this schema has been registered under subject_name. + if self._auto_register: + # The schema name will always be the same. We can't however register + # a schema without a subject so we set the schema_id here to handle + # the initial registration. 
+ self._schema_id = await self._registry.register_schema( + subject, self._schema, self._normalize_schemas) + else: + registered_schema = await self._registry.lookup_schema( + subject, self._schema, self._normalize_schemas) + self._schema_id = registered_schema.schema_id + + self._known_subjects.add(subject) + + if self._to_dict is not None: + value = self._to_dict(obj, ctx) + else: + value = obj + + if latest_schema is not None: + parsed_schema = await self._get_parsed_schema(latest_schema.schema) + field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 + transform(rule_ctx, parsed_schema, msg, field_transform)) + value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, value, get_inline_tags(parsed_schema), + field_transformer) + else: + parsed_schema = self._parsed_schema + + with _ContextStringIO() as fo: + # Write the magic byte and schema ID in network byte order (big endian) + fo.write(pack('>bI', _MAGIC_BYTE, self._schema_id)) + # write the record to the rest of the buffer + schemaless_writer(fo, parsed_schema, value) + + return fo.getvalue() + + async def _get_parsed_schema(self, schema: Schema) -> AvroSchema: + parsed_schema = self._parsed_schemas.get_parsed_schema(schema) + if parsed_schema is not None: + return parsed_schema + + named_schemas = await _resolve_named_schema(schema, self._registry) + prepared_schema = _schema_loads(schema.schema_str) + parsed_schema = parse_schema_with_repo( + prepared_schema.schema_str, named_schemas=named_schemas) + + self._parsed_schemas.set(schema, parsed_schema) + return parsed_schema + + +@asyncinit +class AsyncAvroDeserializer(AsyncBaseDeserializer): + """ + Deserializer for Avro binary encoded data with Confluent Schema Registry + framing. + + +-----------------------------+----------+--------------------------------------------------+ + | Property Name | Type | Description | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | deserialization. | + | | | | + | | | Defaults to False. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata``| dict | the given metadata. | + | | | | + | | | Defaults to None. | + +-----------------------------+----------+--------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-----------------------------+----------+--------------------------------------------------+ + + Note: + By default, Avro complex types are returned as dicts. This behavior can + be overridden by registering a callable ``from_dict`` with the deserializer to + convert the dicts to the desired type. + + See ``avro_consumer.py`` in the examples directory in the examples + directory for example usage. + + Args: + schema_registry_client (SchemaRegistryClient): Confluent Schema Registry + client instance. + + schema_str (str, Schema, optional): Avro reader schema declaration Accepts + either a string or a :py:class:`Schema` instance. 
If not provided, the + writer schema will be used as the reader schema. Note that string + definitions cannot reference other schemas. For referencing other schemas, + use a :py:class:`Schema` instance. + + from_dict (callable, optional): Callable(dict, SerializationContext) -> object. + Converts a dict to an instance of some object. + + return_record_name (bool): If True, when reading a union of records, the result will + be a tuple where the first value is the name of the record and the second value is + the record itself. Defaults to False. + + See Also: + `Apache Avro Schema Declaration `_ + + `Apache Avro Schema Resolution `_ + """ + + __slots__ = ['_reader_schema', '_from_dict', '_return_record_name', + '_schema', '_parsed_schemas'] + + _default_conf = {'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy} + + async def __init__( + self, + schema_registry_client: AsyncSchemaRegistryClient, + schema_str: Union[str, Schema, None] = None, + from_dict: Optional[Callable[[dict, SerializationContext], object]] = None, + return_record_name: bool = False, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None + ): + super().__init__() + schema = None + if schema_str is not None: + if isinstance(schema_str, str): + schema = _schema_loads(schema_str) + elif isinstance(schema_str, Schema): + schema = schema_str + else: + raise TypeError('You must pass either schema string or schema object') + + self._schema = schema + self._registry = schema_registry_client + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._parsed_schemas = ParsedSchemaCache() + self._use_schema_id = None + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + if schema: + self._reader_schema = await self._get_parsed_schema(self._schema) + else: + self._reader_schema = None + + if from_dict is not None and not callable(from_dict): + raise ValueError("from_dict must be callable with the signature " + "from_dict(SerializationContext, dict) -> object") + self._from_dict = from_dict + + self._return_record_name = return_record_name + if not isinstance(self._return_record_name, bool): + raise ValueError("return_record_name must be a boolean value") + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + return self.__deserialize(data, ctx) + + async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + """ + Deserialize Avro binary 
encoded data with Confluent Schema Registry framing to + a dict, or object instance according to from_dict, if specified. + + Arguments: + data (bytes): bytes + + ctx (SerializationContext): Metadata relevant to the serialization + operation. + + Raises: + SerializerError: if an error occurs parsing data. + + Returns: + object: If data is None, then None. Else, a dict, or object instance according + to from_dict, if specified. + """ # noqa: E501 + + if data is None: + return None + + if len(data) <= 5: + raise SerializationError("Expecting data framing of length 6 bytes or " + "more but total data size is {} bytes. This " + "message was not produced with a Confluent " + "Schema Registry serializer".format(len(data))) + + subject = self._subject_name_func(ctx, None) + latest_schema = None + if subject is not None: + latest_schema = await self._get_reader_schema(subject) + + with _ContextStringIO(data) as payload: + magic, schema_id = unpack('>bI', payload.read(5)) + if magic != _MAGIC_BYTE: + raise SerializationError("Unexpected magic byte {}. This message " + "was not produced with a Confluent " + "Schema Registry serializer".format(magic)) + + writer_schema_raw = await self._registry.get_schema(schema_id) + writer_schema = await self._get_parsed_schema(writer_schema_raw) + + if subject is None: + subject = self._subject_name_func(ctx, writer_schema.get("name")) + if subject is not None: + latest_schema = await self._get_reader_schema(subject) + + if latest_schema is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) + reader_schema_raw = latest_schema.schema + reader_schema = await self._get_parsed_schema(latest_schema.schema) + elif self._schema is not None: + migrations = None + reader_schema_raw = self._schema + reader_schema = self._reader_schema + else: + migrations = None + reader_schema_raw = writer_schema_raw + reader_schema = writer_schema + + if migrations: + obj_dict = schemaless_reader(payload, + writer_schema, + None, + self._return_record_name) + obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) + else: + obj_dict = schemaless_reader(payload, + writer_schema, + reader_schema, + self._return_record_name) + + field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + transform(rule_ctx, reader_schema, message, field_transform)) + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, obj_dict, get_inline_tags(reader_schema), + field_transformer) + + if self._from_dict is not None: + return self._from_dict(obj_dict, ctx) + + return obj_dict + + async def _get_parsed_schema(self, schema: Schema) -> AvroSchema: + parsed_schema = self._parsed_schemas.get_parsed_schema(schema) + if parsed_schema is not None: + return parsed_schema + + named_schemas = await _resolve_named_schema(schema, self._registry) + prepared_schema = _schema_loads(schema.schema_str) + parsed_schema = parse_schema_with_repo( + prepared_schema.schema_str, named_schemas=named_schemas) + + self._parsed_schemas.set(schema, parsed_schema) + return parsed_schema diff --git a/src/confluent_kafka/schema_registry/_async/json_schema.py b/src/confluent_kafka/schema_registry/_async/json_schema.py new file mode 100644 index 000000000..c12e5869c --- /dev/null +++ b/src/confluent_kafka/schema_registry/_async/json_schema.py @@ -0,0 +1,643 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. 
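A minimal round-trip sketch pairing the serializer with the deserializer defined above; the URL and topic are placeholders and the same awaitable-construction assumption applies.

import asyncio

from confluent_kafka.schema_registry import AsyncSchemaRegistryClient
from confluent_kafka.schema_registry._async.avro import (AsyncAvroDeserializer,
                                                         AsyncAvroSerializer)
from confluent_kafka.serialization import MessageField, SerializationContext

async def main():
    client = AsyncSchemaRegistryClient({'url': 'http://localhost:8081'})  # placeholder URL
    ctx = SerializationContext('example-topic', MessageField.VALUE)

    serializer = await AsyncAvroSerializer(client, schema_str='{"type": "string"}')
    data = await serializer('round trip', ctx)

    # No reader schema is given, so the writer schema fetched by ID is used to decode
    deserializer = await AsyncAvroDeserializer(client)
    print(await deserializer(data, ctx))   # -> 'round trip'

asyncio.run(main())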
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import struct +from typing import Union, Optional, Tuple, Callable + +from cachetools import LRUCache +from jsonschema import ValidationError +from jsonschema.protocols import Validator +from jsonschema.validators import validator_for +from referencing import Registry, Resource + +from confluent_kafka.schema_registry import (_MAGIC_BYTE, + Schema, + topic_subject_name_strategy, + RuleMode, AsyncSchemaRegistryClient) +from confluent_kafka.schema_registry.common import asyncinit +from confluent_kafka.schema_registry.common.json_schema import ( + DEFAULT_SPEC, JsonSchema, _retrieve_via_httpx, transform +) +from confluent_kafka.schema_registry.common import _ContextStringIO +from confluent_kafka.schema_registry.rule_registry import RuleRegistry +from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, \ + ParsedSchemaCache +from confluent_kafka.serialization import (SerializationError, + SerializationContext) + + +async def _resolve_named_schema( + schema: Schema, schema_registry_client: AsyncSchemaRegistryClient, + ref_registry: Optional[Registry] = None +) -> Registry: + """ + Resolves named schemas referenced by the provided schema recursively. + :param schema: Schema to resolve named schemas for. + :param schema_registry_client: SchemaRegistryClient to use for retrieval. + :param ref_registry: Registry of named schemas resolved recursively. + :return: Registry + """ + if ref_registry is None: + # Retrieve external schemas for backward compatibility + ref_registry = Registry(retrieve=_retrieve_via_httpx) + if schema.references is not None: + for ref in schema.references: + referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True) + ref_registry = await _resolve_named_schema(referenced_schema.schema, schema_registry_client, ref_registry) + referenced_schema_dict = json.loads(referenced_schema.schema.schema_str) + resource = Resource.from_contents( + referenced_schema_dict, default_specification=DEFAULT_SPEC) + ref_registry = ref_registry.with_resource(ref.name, resource) + return ref_registry + + +@asyncinit +class AsyncJSONSerializer(AsyncBaseSerializer): + """ + Serializer that outputs JSON encoded data with Confluent Schema Registry framing. + + Configuration properties: + + +-----------------------------+----------+----------------------------------------------------+ + | Property Name | Type | Description | + +=============================+==========+====================================================+ + | | | If True, automatically register the configured | + | ``auto.register.schemas`` | bool | schema with Confluent Schema Registry if it has | + | | | not previously been associated with the relevant | + | | | subject (determined via subject.name.strategy). | + | | | | + | | | Defaults to True. | + | | | | + | | | Raises SchemaRegistryError if the schema was not | + | | | registered against the subject, or could not be | + | | | successfully registered. 
| + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to normalize schemas, which will | + | ``normalize.schemas`` | bool | transform schemas to have a consistent format, | + | | | including ordering properties and references. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to use the given schema ID for | + | ``use.schema.id`` | int | serialization. | + | | | | + +-----------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | serialization. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to False. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata``| dict | the given metadata. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to None. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to validate the payload against the | + | ``validate`` | bool | the given schema. | + | | | | + +-----------------------------+----------+----------------------------------------------------+ + + Schemas are registered against subject names in Confluent Schema Registry that + define a scope in which the schemas can be evolved. By default, the subject name + is formed by concatenating the topic name with the message field (key or value) + separated by a hyphen. + + i.e. {topic name}-{message field} + + Alternative naming strategies may be configured with the property + ``subject.name.strategy``. + + Supported subject name strategies: + + +--------------------------------------+------------------------------+ + | Subject Name Strategy | Output Format | + +======================================+==============================+ + | topic_subject_name_strategy(default) | {topic name}-{message field} | + +--------------------------------------+------------------------------+ + | topic_record_subject_name_strategy | {topic name}-{record name} | + +--------------------------------------+------------------------------+ + | record_subject_name_strategy | {record name} | + +--------------------------------------+------------------------------+ + + See `Subject name strategy `_ for additional details. + + Notes: + The ``title`` annotation, referred to elsewhere as a record name + is not strictly required by the JSON Schema specification. It is + however required by this serializer in order to register the schema + with Confluent Schema Registry. + + Prior to serialization, all objects must first be converted to + a dict instance. 
This may be handled manually prior to calling + :py:func:`Producer.produce()` or by registering a `to_dict` + callable with JSONSerializer. + + Args: + schema_str (str, Schema): + `JSON Schema definition. `_ + Accepts schema as either a string or a :py:class:`Schema` instance. + Note that string definitions cannot reference other schemas. For + referencing other schemas, use a :py:class:`Schema` instance. + + schema_registry_client (SchemaRegistryClient): Schema Registry + client instance. + + to_dict (callable, optional): Callable(object, SerializationContext) -> dict. + Converts object to a dict. + + conf (dict): JsonSerializer configuration. + """ # noqa: E501 + __slots__ = ['_known_subjects', '_parsed_schema', '_ref_registry', + '_schema', '_schema_id', '_schema_name', '_to_dict', + '_parsed_schemas', '_validators', '_validate', '_json_encode'] + + _default_conf = {'auto.register.schemas': True, + 'normalize.schemas': False, + 'use.schema.id': None, + 'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy, + 'validate': True} + + async def __init__( + self, + schema_str: Union[str, Schema, None], + schema_registry_client: AsyncSchemaRegistryClient, + to_dict: Optional[Callable[[object, SerializationContext], dict]] = None, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None, + json_encode: Optional[Callable] = None, + ): + super().__init__() + if isinstance(schema_str, str): + self._schema = Schema(schema_str, schema_type="JSON") + elif isinstance(schema_str, Schema): + self._schema = schema_str + else: + self._schema = None + + self._json_encode = json_encode or json.dumps + self._registry = schema_registry_client + self._rule_registry = ( + rule_registry if rule_registry else RuleRegistry.get_global_instance() + ) + self._schema_id = None + self._known_subjects = set() + self._parsed_schemas = ParsedSchemaCache() + self._validators = LRUCache(1000) + + if to_dict is not None and not callable(to_dict): + raise ValueError("to_dict must be callable with the signature " + "to_dict(object, SerializationContext)->dict") + + self._to_dict = to_dict + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._auto_register = conf_copy.pop('auto.register.schemas') + if not isinstance(self._auto_register, bool): + raise ValueError("auto.register.schemas must be a boolean value") + + self._normalize_schemas = conf_copy.pop('normalize.schemas') + if not isinstance(self._normalize_schemas, bool): + raise ValueError("normalize.schemas must be a boolean value") + + self._use_schema_id = conf_copy.pop('use.schema.id') + if (self._use_schema_id is not None and + not isinstance(self._use_schema_id, int)): + raise ValueError("use.schema.id must be an int value") + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + if self._use_latest_version and self._auto_register: + raise ValueError("cannot enable both use.latest.version and auto.register.schemas") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not 
callable(self._subject_name_func):
+            raise ValueError("subject.name.strategy must be callable")
+
+        self._validate = conf_copy.pop('validate')
+        if not isinstance(self._validate, bool):
+            raise ValueError("validate must be a boolean value")
+
+        if len(conf_copy) > 0:
+            raise ValueError("Unrecognized properties: {}"
+                             .format(", ".join(conf_copy.keys())))
+
+        schema_dict, ref_registry = await self._get_parsed_schema(self._schema)
+        if schema_dict:
+            schema_name = schema_dict.get('title', None)
+        else:
+            schema_name = None
+
+        self._schema_name = schema_name
+        self._parsed_schema = schema_dict
+        self._ref_registry = ref_registry
+
+        for rule in self._rule_registry.get_executors():
+            rule.configure(self._registry.config() if self._registry else {},
+                           rule_conf if rule_conf else {})
+
+    def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
+        return self.__serialize(obj, ctx)
+
+    async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
+        """
+        Serializes an object to JSON, prepending it with Confluent Schema Registry
+        framing.
+
+        Args:
+            obj (object): The object instance to serialize.
+
+            ctx (SerializationContext): Metadata relevant to the serialization
+                operation.
+
+        Raises:
+            SerializerError if any error occurs serializing obj.
+
+        Returns:
+            bytes: None if obj is None, else a byte array containing the JSON
+            serialized data with Confluent Schema Registry framing.
+        """
+
+        if obj is None:
+            return None
+
+        subject = self._subject_name_func(ctx, self._schema_name)
+        latest_schema = await self._get_reader_schema(subject)
+        if latest_schema is not None:
+            self._schema_id = latest_schema.schema_id
+        elif subject not in self._known_subjects:
+            # Check to ensure this schema has been registered under subject_name.
+            if self._auto_register:
+                # The schema name will always be the same. We can't however register
+                # a schema without a subject so we set the schema_id here to handle
+                # the initial registration.
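+                # The resulting schema ID is what gets written below into the
+                # Confluent wire-format header: magic byte 0 followed by the
+                # schema ID as a big-endian 32-bit integer
+                # (see the struct.pack(">bI", ...) call further down).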
+ self._schema_id = await self._registry.register_schema(subject, + self._schema, + self._normalize_schemas) + else: + registered_schema = await self._registry.lookup_schema(subject, + self._schema, + self._normalize_schemas) + self._schema_id = registered_schema.schema_id + + self._known_subjects.add(subject) + + if self._to_dict is not None: + value = self._to_dict(obj, ctx) + else: + value = obj + + if latest_schema is not None: + schema = latest_schema.schema + parsed_schema, ref_registry = await self._get_parsed_schema(latest_schema.schema) + root_resource = Resource.from_contents( + parsed_schema, default_specification=DEFAULT_SPEC) + ref_resolver = ref_registry.resolver_with_root(root_resource) + field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 + transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) + value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, value, None, + field_transformer) + else: + schema = self._schema + parsed_schema, ref_registry = self._parsed_schema, self._ref_registry + + if self._validate: + try: + validator = self._get_validator(schema, parsed_schema, ref_registry) + validator.validate(value) + except ValidationError as ve: + raise SerializationError(ve.message) + + with _ContextStringIO() as fo: + # Write the magic byte and schema ID in network byte order (big endian) + fo.write(struct.pack(">bI", _MAGIC_BYTE, self._schema_id)) + # JSON dump always writes a str never bytes + # https://docs.python.org/3/library/json.html + encoded_value = self._json_encode(value) + if isinstance(encoded_value, str): + encoded_value = encoded_value.encode("utf8") + fo.write(encoded_value) + + return fo.getvalue() + + async def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]: + if schema is None: + return None, None + + result = self._parsed_schemas.get_parsed_schema(schema) + if result is not None: + return result + + ref_registry = await _resolve_named_schema(schema, self._registry) + parsed_schema = json.loads(schema.schema_str) + + self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) + return parsed_schema, ref_registry + + def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Registry) -> Validator: + validator = self._validators.get(schema, None) + if validator is not None: + return validator + + cls = validator_for(parsed_schema) + cls.check_schema(parsed_schema) + validator = cls(parsed_schema, registry=registry) + + self._validators[schema] = validator + return validator + +@asyncinit +class AsyncJSONDeserializer(AsyncBaseDeserializer): + """ + Deserializer for JSON encoded data with Confluent Schema Registry + framing. + + Configuration properties: + + +-----------------------------+----------+----------------------------------------------------+ + | Property Name | Type | Description | + +=============================+==========+====================================================+ + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | deserialization. | + | | | | + | | | Defaults to False. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata``| dict | the given metadata. | + | | | | + | | | Defaults to None. 
| + +-----------------------------+----------+----------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to validate the payload against the | + | ``validate`` | bool | the given schema. | + | | | | + +-----------------------------+----------+----------------------------------------------------+ + + Args: + schema_str (str, Schema, optional): + `JSON schema definition `_ + Accepts schema as either a string or a :py:class:`Schema` instance. + Note that string definitions cannot reference other schemas. For referencing other schemas, + use a :py:class:`Schema` instance. If not provided, schemas will be + retrieved from schema_registry_client based on the schema ID in the + wire header of each message. + + from_dict (callable, optional): Callable(dict, SerializationContext) -> object. + Converts a dict to a Python object instance. + + schema_registry_client (SchemaRegistryClient, optional): Schema Registry client instance. Needed if ``schema_str`` is a schema referencing other schemas or is not provided. + """ # noqa: E501 + + __slots__ = ['_reader_schema', '_ref_registry', '_from_dict', '_schema', + '_parsed_schemas', '_validators', '_validate', '_json_decode'] + + _default_conf = {'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy, + 'validate': True} + + async def __init__( + self, + schema_str: Union[str, Schema, None], + from_dict: Optional[Callable[[dict, SerializationContext], object]] = None, + schema_registry_client: Optional[AsyncSchemaRegistryClient] = None, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None, + json_decode: Optional[Callable] = None, + ): + super().__init__() + if isinstance(schema_str, str): + schema = Schema(schema_str, schema_type="JSON") + elif isinstance(schema_str, Schema): + schema = schema_str + if bool(schema.references) and schema_registry_client is None: + raise ValueError( + """schema_registry_client must be provided if "schema_str" is a Schema instance with references""") + elif schema_str is None: + if schema_registry_client is None: + raise ValueError( + """schema_registry_client must be provided if "schema_str" is not provided""" + ) + schema = schema_str + else: + raise TypeError('You must pass either str or Schema') + + self._schema = schema + self._registry = schema_registry_client + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._parsed_schemas = ParsedSchemaCache() + self._validators = LRUCache(1000) + self._json_decode = json_decode or json.loads + self._use_schema_id = None + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not 
isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + self._validate = conf_copy.pop('validate') + if not isinstance(self._validate, bool): + raise ValueError("validate must be a boolean value") + + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + if schema: + self._reader_schema, self._ref_registry = await self._get_parsed_schema(self._schema) + else: + self._reader_schema, self._ref_registry = None, None + + if from_dict is not None and not callable(from_dict): + raise ValueError("from_dict must be callable with the signature" + " from_dict(dict, SerializationContext) -> object") + + self._from_dict = from_dict + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + return self.__serialize(data, ctx) + + async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + """ + Deserialize a JSON encoded record with Confluent Schema Registry framing to + a dict, or object instance according to from_dict if from_dict is specified. + + Args: + data (bytes): A JSON serialized record with Confluent Schema Registry framing. + + ctx (SerializationContext): Metadata relevant to the serialization operation. + + Returns: + A dict, or object instance according to from_dict if from_dict is specified. + + Raises: + SerializerError: If there was an error reading the Confluent framing data, or + if ``data`` was not successfully validated with the configured schema. + """ + + if data is None: + return None + + if len(data) <= 5: + raise SerializationError("Expecting data framing of length 6 bytes or " + "more but total data size is {} bytes. This " + "message was not produced with a Confluent " + "Schema Registry serializer".format(len(data))) + + subject = self._subject_name_func(ctx, None) + latest_schema = None + if subject is not None and self._registry is not None: + latest_schema = await self._get_reader_schema(subject) + + with _ContextStringIO(data) as payload: + magic, schema_id = struct.unpack('>bI', payload.read(5)) + if magic != _MAGIC_BYTE: + raise SerializationError("Unexpected magic byte {}. 
This message " + "was not produced with a Confluent " + "Schema Registry serializer".format(magic)) + + # JSON documents are self-describing; no need to query schema + obj_dict = self._json_decode(payload.read()) + + if self._registry is not None: + writer_schema_raw = await self._registry.get_schema(schema_id) + writer_schema, writer_ref_registry = await self._get_parsed_schema(writer_schema_raw) + if subject is None: + subject = self._subject_name_func(ctx, writer_schema.get("title")) + if subject is not None: + latest_schema = await self._get_reader_schema(subject) + else: + writer_schema_raw = None + writer_schema, writer_ref_registry = None, None + + if latest_schema is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) + reader_schema_raw = latest_schema.schema + reader_schema, reader_ref_registry = await self._get_parsed_schema(latest_schema.schema) + elif self._schema is not None: + migrations = None + reader_schema_raw = self._schema + reader_schema, reader_ref_registry = self._reader_schema, self._ref_registry + else: + migrations = None + reader_schema_raw = writer_schema_raw + reader_schema, reader_ref_registry = writer_schema, writer_ref_registry + + if migrations: + obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) + + reader_root_resource = Resource.from_contents( + reader_schema, default_specification=DEFAULT_SPEC) + reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) + field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, + "$", message, field_transform)) + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, obj_dict, None, + field_transformer) + + if self._validate: + try: + validator = self._get_validator(reader_schema_raw, reader_schema, reader_ref_registry) + validator.validate(obj_dict) + except ValidationError as ve: + raise SerializationError(ve.message) + + if self._from_dict is not None: + return self._from_dict(obj_dict, ctx) + + return obj_dict + + async def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]: + if schema is None: + return None, None + + result = self._parsed_schemas.get_parsed_schema(schema) + if result is not None: + return result + + ref_registry = await _resolve_named_schema(schema, self._registry) + parsed_schema = json.loads(schema.schema_str) + + self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) + return parsed_schema, ref_registry + + def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Registry) -> Validator: + validator = self._validators.get(schema, None) + if validator is not None: + return validator + + cls = validator_for(parsed_schema) + cls.check_schema(parsed_schema) + validator = cls(parsed_schema, registry=registry) + + self._validators[schema] = validator + return validator diff --git a/src/confluent_kafka/schema_registry/_async/protobuf.py b/src/confluent_kafka/schema_registry/_async/protobuf.py new file mode 100644 index 000000000..f6e850dea --- /dev/null +++ b/src/confluent_kafka/schema_registry/_async/protobuf.py @@ -0,0 +1,801 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020-2022 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import struct +import warnings +from typing import Set, List, Union, Optional, Tuple + +from google.protobuf import json_format, descriptor_pb2 +from google.protobuf.descriptor_pool import DescriptorPool +from google.protobuf.descriptor import Descriptor, FileDescriptor +from google.protobuf.message import DecodeError, Message +from google.protobuf.message_factory import GetMessageClass + +from confluent_kafka.schema_registry.common import (_MAGIC_BYTE, _ContextStringIO, + reference_subject_name_strategy, + topic_subject_name_strategy) +from confluent_kafka.schema_registry.schema_registry_client import AsyncSchemaRegistryClient +from confluent_kafka.schema_registry.common.protobuf import _bytes, _create_index_array, _init_pool, _is_builtin, _schema_to_str, _str_to_proto, transform +from confluent_kafka.schema_registry.rule_registry import RuleRegistry +from confluent_kafka.schema_registry import (Schema, + SchemaReference, + RuleMode) +from confluent_kafka.serialization import SerializationError, \ + SerializationContext +from confluent_kafka.schema_registry.common import asyncinit +from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, ParsedSchemaCache + + +async def _resolve_named_schema( + schema: Schema, + schema_registry_client: AsyncSchemaRegistryClient, + pool: DescriptorPool, + visited: Optional[Set[str]] = None +): + """ + Resolves named schemas referenced by the provided schema recursively. + :param schema: Schema to resolve named schemas for. + :param schema_registry_client: AsyncSchemaRegistryClient to use for retrieval. + :param pool: DescriptorPool to add resolved schemas to. + :return: DescriptorPool + """ + if visited is None: + visited = set() + if schema.references is not None: + for ref in schema.references: + if _is_builtin(ref.name) or ref.name in visited: + continue + visited.add(ref.name) + referenced_schema = await schema_registry_client.get_version(ref.subject, ref.version, True, 'serialized') + await _resolve_named_schema(referenced_schema.schema, schema_registry_client, pool, visited) + file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) + pool.Add(file_descriptor_proto) + + +@asyncinit +class AsyncProtobufSerializer(AsyncBaseSerializer): + """ + Serializer for Protobuf Message derived classes. Serialization format is Protobuf, + with Confluent Schema Registry framing. + + Configuration properties: + + +-------------------------------------+----------+------------------------------------------------------+ + | Property Name | Type | Description | + +=====================================+==========+======================================================+ + | | | If True, automatically register the configured | + | ``auto.register.schemas`` | bool | schema with Confluent Schema Registry if it has | + | | | not previously been associated with the relevant | + | | | subject (determined via subject.name.strategy). | + | | | | + | | | Defaults to True. 
| + | | | | + | | | Raises SchemaRegistryError if the schema was not | + | | | registered against the subject, or could not be | + | | | successfully registered. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to normalize schemas, which will | + | ``normalize.schemas`` | bool | transform schemas to have a consistent format, | + | | | including ordering properties and references. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to use the given schema ID for | + | ``use.schema.id`` | int | serialization. | + | | | | + +-----------------------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | serialization. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to False. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata`` | dict | the given metadata. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to None. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether or not to skip known types when resolving | + | ``skip.known.types`` | bool | schema dependencies. | + | | | | + | | | Defaults to True. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``reference.subject.name.strategy`` | callable | Defines how Schema Registry subject names for schema | + | | | references are constructed. | + | | | | + | | | Defaults to reference_subject_name_strategy | + +-------------------------------------+----------+------------------------------------------------------+ + | ``use.deprecated.format`` | bool | Specifies whether the Protobuf serializer should | + | | | serialize message indexes without zig-zag encoding. | + | | | This option must be explicitly configured as older | + | | | and newer Protobuf producers are incompatible. | + | | | If the consumers of the topic being produced to are | + | | | using confluent-kafka-python <1.8 then this property | + | | | must be set to True until all old consumers have | + | | | have been upgraded. | + | | | | + | | | Warning: This configuration property will be removed | + | | | in a future version of the client. | + +-------------------------------------+----------+------------------------------------------------------+ + + Schemas are registered against subject names in Confluent Schema Registry that + define a scope in which the schemas can be evolved. 
By default, the subject name + is formed by concatenating the topic name with the message field (key or value) + separated by a hyphen. + + i.e. {topic name}-{message field} + + Alternative naming strategies may be configured with the property + ``subject.name.strategy``. + + Supported subject name strategies + + +--------------------------------------+------------------------------+ + | Subject Name Strategy | Output Format | + +======================================+==============================+ + | topic_subject_name_strategy(default) | {topic name}-{message field} | + +--------------------------------------+------------------------------+ + | topic_record_subject_name_strategy | {topic name}-{record name} | + +--------------------------------------+------------------------------+ + | record_subject_name_strategy | {record name} | + +--------------------------------------+------------------------------+ + + See `Subject name strategy `_ for additional details. + + Args: + msg_type (Message): Protobuf Message type. + + schema_registry_client (SchemaRegistryClient): Schema Registry + client instance. + + conf (dict): ProtobufSerializer configuration. + + See Also: + `Protobuf API reference `_ + """ # noqa: E501 + __slots__ = ['_skip_known_types', '_known_subjects', '_msg_class', '_index_array', + '_schema', '_schema_id', '_ref_reference_subject_func', + '_use_deprecated_format', '_parsed_schemas'] + + _default_conf = { + 'auto.register.schemas': True, + 'normalize.schemas': False, + 'use.schema.id': None, + 'use.latest.version': False, + 'use.latest.with.metadata': None, + 'skip.known.types': True, + 'subject.name.strategy': topic_subject_name_strategy, + 'reference.subject.name.strategy': reference_subject_name_strategy, + 'use.deprecated.format': False, + } + + async def __init__( + self, + msg_type: Message, + schema_registry_client: AsyncSchemaRegistryClient, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None + ): + super().__init__() + + if conf is None or 'use.deprecated.format' not in conf: + raise RuntimeError( + "ProtobufSerializer: the 'use.deprecated.format' configuration " + "property must be explicitly set due to backward incompatibility " + "with older confluent-kafka-python Protobuf producers and consumers. 
" + "See the release notes for more details") + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._auto_register = conf_copy.pop('auto.register.schemas') + if not isinstance(self._auto_register, bool): + raise ValueError("auto.register.schemas must be a boolean value") + + self._normalize_schemas = conf_copy.pop('normalize.schemas') + if not isinstance(self._normalize_schemas, bool): + raise ValueError("normalize.schemas must be a boolean value") + + self._use_schema_id = conf_copy.pop('use.schema.id') + if (self._use_schema_id is not None and + not isinstance(self._use_schema_id, int)): + raise ValueError("use.schema.id must be an int value") + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + if self._use_latest_version and self._auto_register: + raise ValueError("cannot enable both use.latest.version and auto.register.schemas") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._skip_known_types = conf_copy.pop('skip.known.types') + if not isinstance(self._skip_known_types, bool): + raise ValueError("skip.known.types must be a boolean value") + + self._use_deprecated_format = conf_copy.pop('use.deprecated.format') + if not isinstance(self._use_deprecated_format, bool): + raise ValueError("use.deprecated.format must be a boolean value") + if self._use_deprecated_format: + warnings.warn("ProtobufSerializer: the 'use.deprecated.format' " + "configuration property, and the ability to use the " + "old incorrect Protobuf serializer heading format " + "introduced in confluent-kafka-python v1.4.0, " + "will be removed in an upcoming release in 2021 Q2. " + "Please migrate your Python Protobuf producers and " + "consumers to 'use.deprecated.format':False as " + "soon as possible") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + self._ref_reference_subject_func = conf_copy.pop( + 'reference.subject.name.strategy') + if not callable(self._ref_reference_subject_func): + raise ValueError("subject.name.strategy must be callable") + + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + self._registry = schema_registry_client + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._schema_id = None + self._known_subjects = set() + self._msg_class = msg_type + self._parsed_schemas = ParsedSchemaCache() + + descriptor = msg_type.DESCRIPTOR + self._index_array = _create_index_array(descriptor) + self._schema = Schema(_schema_to_str(descriptor.file), + schema_type='PROTOBUF') + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + @staticmethod + def _write_varint(buf: io.BytesIO, val: int, zigzag: bool = True): + """ + Writes val to buf, either using zigzag or uvarint encoding. + + Args: + buf (BytesIO): buffer to write to. + val (int): integer to be encoded. 
+ zigzag (bool): whether to encode in zigzag or uvarint encoding + """ + + if zigzag: + val = (val << 1) ^ (val >> 63) + + while (val & ~0x7f) != 0: + buf.write(_bytes((val & 0x7f) | 0x80)) + val >>= 7 + buf.write(_bytes(val)) + + @staticmethod + def _encode_varints(buf: io.BytesIO, ints: List[int], zigzag: bool = True): + """ + Encodes each int as a uvarint onto buf + + Args: + buf (BytesIO): buffer to write to. + ints ([int]): ints to be encoded. + zigzag (bool): whether to encode in zigzag or uvarint encoding + """ + + assert len(ints) > 0 + # The root element at the 0 position does not need a length prefix. + if ints == [0]: + buf.write(_bytes(0x00)) + return + + AsyncProtobufSerializer._write_varint(buf, len(ints), zigzag=zigzag) + + for value in ints: + AsyncProtobufSerializer._write_varint(buf, value, zigzag=zigzag) + + async def _resolve_dependencies( + self, ctx: SerializationContext, + file_desc: FileDescriptor + ) -> List[SchemaReference]: + """ + Resolves and optionally registers schema references recursively. + + Args: + ctx (SerializationContext): Serialization context. + + file_desc (FileDescriptor): file descriptor to traverse. + """ + + schema_refs = [] + for dep in file_desc.dependencies: + if self._skip_known_types and _is_builtin(dep.name): + continue + dep_refs = await self._resolve_dependencies(ctx, dep) + subject = self._ref_reference_subject_func(ctx, dep) + schema = Schema(_schema_to_str(dep), + references=dep_refs, + schema_type='PROTOBUF') + if self._auto_register: + await self._registry.register_schema(subject, schema) + + reference = await self._registry.lookup_schema(subject, schema) + # schema_refs are per file descriptor + schema_refs.append(SchemaReference(dep.name, + subject, + reference.version)) + return schema_refs + + def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + return self.__serialize(message, ctx) + + async def __serialize(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + """ + Serializes an instance of a class derived from Protobuf Message, and prepends + it with Confluent Schema Registry framing. + + Args: + message (Message): An instance of a class derived from Protobuf Message. + + ctx (SerializationContext): Metadata relevant to the serialization. + operation. + + Raises: + SerializerError if any error occurs during serialization. + + Returns: + None if messages is None, else a byte array containing the Protobuf + serialized message with Confluent Schema Registry framing. 
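+
+            The byte layout written out below is::
+
+                [magic 0x00][schema id, 4-byte big-endian]
+                [varint-encoded message index array][protobuf message bytes]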
+ """ + + if message is None: + return None + + if not isinstance(message, self._msg_class): + raise ValueError("message must be of type {} not {}" + .format(self._msg_class, type(message))) + + subject = self._subject_name_func(ctx, + message.DESCRIPTOR.full_name) + latest_schema = await self._get_reader_schema(subject, fmt='serialized') + if latest_schema is not None: + self._schema_id = latest_schema.schema_id + elif subject not in self._known_subjects: + references = await self._resolve_dependencies( + ctx, message.DESCRIPTOR.file) + self._schema = Schema( + self._schema.schema_str, + self._schema.schema_type, + references + ) + + if self._auto_register: + self._schema_id = await self._registry.register_schema(subject, + self._schema, + self._normalize_schemas) + else: + self._schema_id = await self._registry.lookup_schema( + subject, self._schema, self._normalize_schemas).schema_id + + self._known_subjects.add(subject) + + if latest_schema is not None: + fd_proto, pool = await self._get_parsed_schema(latest_schema.schema) + fd = pool.FindFileByName(fd_proto.name) + desc = fd.message_types_by_name[message.DESCRIPTOR.name] + field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 + transform(rule_ctx, desc, msg, field_transform)) + message = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, message, None, + field_transformer) + + with _ContextStringIO() as fo: + # Write the magic byte and schema ID in network byte order + # (big endian) + fo.write(struct.pack('>bI', _MAGIC_BYTE, self._schema_id)) + # write the index array that specifies the message descriptor + # of the serialized data. + self._encode_varints(fo, self._index_array, + zigzag=not self._use_deprecated_format) + # write the serialized data itself + fo.write(message.SerializeToString()) + return fo.getvalue() + + async def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]: + result = self._parsed_schemas.get_parsed_schema(schema) + if result is not None: + return result + + pool = DescriptorPool() + _init_pool(pool) + await _resolve_named_schema(schema, self._registry, pool) + fd_proto = _str_to_proto("default", schema.schema_str) + pool.Add(fd_proto) + self._parsed_schemas.set(schema, (fd_proto, pool)) + return fd_proto, pool + + +@asyncinit +class AsyncProtobufDeserializer(AsyncBaseDeserializer): + """ + Deserializer for Protobuf serialized data with Confluent Schema Registry framing. + + Args: + message_type (Message derived type): Protobuf Message type. + conf (dict): Configuration dictionary. + + ProtobufDeserializer configuration properties: + + +-------------------------------------+----------+------------------------------------------------------+ + | Property Name | Type | Description | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | deserialization. | + | | | | + | | | Defaults to False. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata`` | dict | the given metadata. | + | | | | + | | | Defaults to None. 
| + +-------------------------------------+----------+------------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka. schema_registry | + | | | namespace . | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-------------------------------------+----------+------------------------------------------------------+ + | ``use.deprecated.format`` | bool | Specifies whether the Protobuf deserializer should | + | | | deserialize message indexes without zig-zag encoding.| + | | | This option must be explicitly configured as older | + | | | and newer Protobuf producers are incompatible. | + | | | If Protobuf messages in the topic to consume were | + | | | produced with confluent-kafka-python <1.8 then this | + | | | property must be set to True until all old messages | + | | | have been processed and producers have been upgraded.| + | | | Warning: This configuration property will be removed | + | | | in a future version of the client. | + +-------------------------------------+----------+------------------------------------------------------+ + + + See Also: + `Protobuf API reference `_ + """ + + __slots__ = ['_msg_class', '_use_deprecated_format', '_parsed_schemas'] + + _default_conf = { + 'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy, + 'use.deprecated.format': False, + } + + async def __init__( + self, + message_type: Message, + conf: Optional[dict] = None, + schema_registry_client: Optional[AsyncSchemaRegistryClient] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None + ): + super().__init__() + + self._registry = schema_registry_client + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._parsed_schemas = ParsedSchemaCache() + self._use_schema_id = None + + # Require use.deprecated.format to be explicitly configured + # during a transitionary period since old/new format are + # incompatible. + if conf is None or 'use.deprecated.format' not in conf: + raise RuntimeError( + "ProtobufDeserializer: the 'use.deprecated.format' configuration " + "property must be explicitly set due to backward incompatibility " + "with older confluent-kafka-python Protobuf producers and consumers. 
" + "See the release notes for more details") + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + self._use_deprecated_format = conf_copy.pop('use.deprecated.format') + if not isinstance(self._use_deprecated_format, bool): + raise ValueError("use.deprecated.format must be a boolean value") + if self._use_deprecated_format: + warnings.warn("ProtobufDeserializer: the 'use.deprecated.format' " + "configuration property, and the ability to use the " + "old incorrect Protobuf serializer heading format " + "introduced in confluent-kafka-python v1.4.0, " + "will be removed in an upcoming release in 2022 Q2. " + "Please migrate your Python Protobuf producers and " + "consumers to 'use.deprecated.format':False as " + "soon as possible") + + descriptor = message_type.DESCRIPTOR + self._msg_class = GetMessageClass(descriptor) + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + @staticmethod + def _decode_varint(buf: io.BytesIO, zigzag: bool = True) -> int: + """ + Decodes a single varint from a buffer. + + Args: + buf (BytesIO): buffer to read from + zigzag (bool): decode as zigzag or uvarint + + Returns: + int: decoded varint + + Raises: + EOFError: if buffer is empty + """ + + value = 0 + shift = 0 + try: + while True: + i = AsyncProtobufDeserializer._read_byte(buf) + + value |= (i & 0x7f) << shift + shift += 7 + if not (i & 0x80): + break + + if zigzag: + value = (value >> 1) ^ -(value & 1) + + return value + + except EOFError: + raise EOFError("Unexpected EOF while reading index") + + @staticmethod + def _read_byte(buf: io.BytesIO) -> int: + """ + Read one byte from buf as an int. + + Args: + buf (BytesIO): The buffer to read from. + + .. _ord: + https://docs.python.org/2/library/functions.html#ord + """ + + i = buf.read(1) + if i == b'': + raise EOFError("Unexpected EOF encountered") + return ord(i) + + @staticmethod + def _read_index_array(buf: io.BytesIO, zigzag: bool = True) -> List[int]: + """ + Read an index array from buf that specifies the message + descriptor of interest in the file descriptor. + + Args: + buf (BytesIO): The buffer to read from. + + Returns: + list of int: The index array. 
+ """ + + size = AsyncProtobufDeserializer._decode_varint(buf, zigzag=zigzag) + if size < 0 or size > 100000: + raise DecodeError("Invalid Protobuf msgidx array length") + + if size == 0: + return [0] + + msg_index = [] + for _ in range(size): + msg_index.append(AsyncProtobufDeserializer._decode_varint(buf, + zigzag=zigzag)) + + return msg_index + + def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + return self.__serialize(data, ctx) + + async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + """ + Deserialize a serialized protobuf message with Confluent Schema Registry + framing. + + Args: + data (bytes): Serialized protobuf message with Confluent Schema + Registry framing. + + ctx (SerializationContext): Metadata relevant to the serialization + operation. + + Returns: + Message: Protobuf Message instance. + + Raises: + SerializerError: If there was an error reading the Confluent framing + data, or parsing the protobuf serialized message. + """ + + if data is None: + return None + + # SR wire protocol + msg_index length + if len(data) < 6: + raise SerializationError("Expecting data framing of length 6 bytes or " + "more but total data size is {} bytes. This " + "message was not produced with a Confluent " + "Schema Registry serializer".format(len(data))) + + subject = self._subject_name_func(ctx, None) + latest_schema = None + if subject is not None and self._registry is not None: + latest_schema = await self._get_reader_schema(subject, fmt='serialized') + + with _ContextStringIO(data) as payload: + magic, schema_id = struct.unpack('>bI', payload.read(5)) + if magic != _MAGIC_BYTE: + raise SerializationError("Unknown magic byte. This message was " + "not produced with a Confluent " + "Schema Registry serializer") + + msg_index = self._read_index_array(payload, zigzag=not self._use_deprecated_format) + + if self._registry is not None: + writer_schema_raw = await self._registry.get_schema(schema_id, fmt='serialized') + fd_proto, pool = await self._get_parsed_schema(writer_schema_raw) + writer_schema = pool.FindFileByName(fd_proto.name) + writer_desc = self._get_message_desc(pool, writer_schema, msg_index) + if subject is None: + subject = self._subject_name_func(ctx, writer_desc.full_name) + if subject is not None: + latest_schema = await self._get_reader_schema(subject, fmt='serialized') + else: + writer_schema_raw = None + writer_schema = None + + if latest_schema is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) + reader_schema_raw = latest_schema.schema + fd_proto, pool = await self._get_parsed_schema(latest_schema.schema) + reader_schema = pool.FindFileByName(fd_proto.name) + else: + migrations = None + reader_schema_raw = writer_schema_raw + reader_schema = writer_schema + + if reader_schema is not None: + # Initialize reader desc to first message in file + reader_desc = self._get_message_desc(pool, reader_schema, [0]) + # Attempt to find a reader desc with the same name as the writer + reader_desc = reader_schema.message_types_by_name.get(writer_desc.name, reader_desc) + + if migrations: + msg = GetMessageClass(writer_desc)() + try: + msg.ParseFromString(payload.read()) + except DecodeError as e: + raise SerializationError(str(e)) + + obj_dict = json_format.MessageToDict(msg, True) + obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) + msg = GetMessageClass(reader_desc)() + msg = json_format.ParseDict(obj_dict, msg) + 
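+                # The writer's message was round-tripped through a dict above so
+                # that any schema migrations could be applied before re-parsing
+                # it as the reader's message type.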
else: + # Protobuf Messages are self-describing; no need to query schema + msg = self._msg_class() + try: + msg.ParseFromString(payload.read()) + except DecodeError as e: + raise SerializationError(str(e)) + + field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + transform(rule_ctx, reader_desc, message, field_transform)) + msg = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, msg, None, + field_transformer) + + return msg + + async def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]: + result = self._parsed_schemas.get_parsed_schema(schema) + if result is not None: + return result + + pool = DescriptorPool() + _init_pool(pool) + await _resolve_named_schema(schema, self._registry, pool) + fd_proto = _str_to_proto("default", schema.schema_str) + pool.Add(fd_proto) + self._parsed_schemas.set(schema, (fd_proto, pool)) + return fd_proto, pool + + def _get_message_desc( + self, pool: DescriptorPool, fd: FileDescriptor, + msg_index: List[int] + ) -> Descriptor: + file_desc_proto = descriptor_pb2.FileDescriptorProto() + fd.CopyToProto(file_desc_proto) + (full_name, desc_proto) = self._get_message_desc_proto("", file_desc_proto, msg_index) + package = file_desc_proto.package + qualified_name = package + "." + full_name if package else full_name + return pool.FindMessageTypeByName(qualified_name) + + def _get_message_desc_proto( + self, + path: str, + desc: Union[descriptor_pb2.FileDescriptorProto, descriptor_pb2.DescriptorProto], + msg_index: List[int] + ) -> Tuple[str, descriptor_pb2.DescriptorProto]: + index = msg_index[0] + if isinstance(desc, descriptor_pb2.FileDescriptorProto): + msg = desc.message_type[index] + path = path + "." + msg.name if path else msg.name + if len(msg_index) == 1: + return path, msg + return self._get_message_desc_proto(path, msg, msg_index[1:]) + else: + msg = desc.nested_type[index] + path = path + "." + msg.name if path else msg.name + if len(msg_index) == 1: + return path, msg + return self._get_message_desc_proto(path, msg, msg_index[1:]) diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py new file mode 100644 index 000000000..ea49e4bd0 --- /dev/null +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -0,0 +1,1115 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import asyncio +import json +import logging +import time +import urllib +from urllib.parse import unquote, urlparse + +import httpx +from typing import List, Dict, Optional, Union, Any, Tuple, Callable + +from cachetools import TTLCache, LRUCache +from httpx import Response + +from authlib.integrations.httpx_client import AsyncOAuth2Client + +from confluent_kafka.schema_registry.error import SchemaRegistryError, OAuthTokenError +from confluent_kafka.schema_registry.common.schema_registry_client import ( + RegisteredSchema, + ServerConfig, + is_success, + is_retriable, + _BearerFieldProvider, + full_jitter, + _SchemaCache, + Schema, +) + +# TODO: consider adding `six` dependency or employing a compat file +# Python 2.7 is officially EOL so compatibility issue will be come more the norm. +# We need a better way to handle these issues. +# Six is one possibility but the compat file pattern used by requests +# is also quite nice. +# +# six: https://pypi.org/project/six/ +# compat file : https://github.com/psf/requests/blob/master/requests/compat.py +try: + string_type = basestring # noqa + + def _urlencode(value: str) -> str: + return urllib.quote(value, safe='') +except NameError: + string_type = str + + def _urlencode(value: str) -> str: + return urllib.parse.quote(value, safe='') + +log = logging.getLogger(__name__) + + +class _AsyncStaticFieldProvider(_BearerFieldProvider): + def __init__(self, token: str, logical_cluster: str, identity_pool: str): + self.token = token + self.logical_cluster = logical_cluster + self.identity_pool = identity_pool + + async def get_bearer_fields(self) -> dict: + return {'bearer.auth.token': self.token, 'bearer.auth.logical.cluster': self.logical_cluster, + 'bearer.auth.identity.pool.id': self.identity_pool} + + +class _AsyncCustomOAuthClient(_BearerFieldProvider): + def __init__(self, custom_function: Callable[[Dict], Dict], custom_config: dict): + self.custom_function = custom_function + self.custom_config = custom_config + + async def get_bearer_fields(self) -> dict: + return await self.custom_function(self.custom_config) + + +class _AsyncOAuthClient(_BearerFieldProvider): + def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str, logical_cluster: str, + identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int): + self.token = None + self.logical_cluster = logical_cluster + self.identity_pool = identity_pool + self.client = AsyncOAuth2Client(client_id=client_id, client_secret=client_secret, scope=scope) + self.token_endpoint = token_endpoint + self.max_retries = max_retries + self.retries_wait_ms = retries_wait_ms + self.retries_max_wait_ms = retries_max_wait_ms + self.token_expiry_threshold = 0.8 + + async def get_bearer_fields(self) -> dict: + return { + 'bearer.auth.token': await self.get_access_token(), + 'bearer.auth.logical.cluster': self.logical_cluster, + 'bearer.auth.identity.pool.id': self.identity_pool + } + + def token_expired(self) -> bool: + expiry_window = self.token['expires_in'] * self.token_expiry_threshold + + return self.token['expires_at'] < time.time() + expiry_window + + async def get_access_token(self) -> str: + if not self.token or self.token_expired(): + await self.generate_access_token() + + return self.token['access_token'] + + async def generate_access_token(self) -> None: + for i in range(self.max_retries + 1): + try: + self.token = await self.client.fetch_token(url=self.token_endpoint, grant_type='client_credentials') + return + except Exception as e: + if i >= 
self.max_retries: + raise OAuthTokenError(f"Failed to retrieve token after {self.max_retries} " + f"attempts due to error: {str(e)}") + await asyncio.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000) + + +class _AsyncBaseRestClient(object): + + def __init__(self, conf: dict): + # copy dict to avoid mutating the original + conf_copy = conf.copy() + + base_url = conf_copy.pop('url', None) + if base_url is None: + raise ValueError("Missing required configuration property url") + if not isinstance(base_url, string_type): + raise TypeError("url must be a str, not " + str(type(base_url))) + base_urls = [] + for url in base_url.split(','): + url = url.strip().rstrip('/') + if not url.startswith('http') and not url.startswith('mock'): + raise ValueError("Invalid url {}".format(url)) + base_urls.append(url) + if not base_urls: + raise ValueError("Missing required configuration property url") + self.base_urls = base_urls + + self.verify = True + ca = conf_copy.pop('ssl.ca.location', None) + if ca is not None: + self.verify = ca + + key: Optional[str] = conf_copy.pop('ssl.key.location', None) + client_cert: Optional[str] = conf_copy.pop('ssl.certificate.location', None) + self.cert: Union[str, Tuple[str, str], None] = None + + if client_cert is not None and key is not None: + self.cert = (client_cert, key) + + if client_cert is not None and key is None: + self.cert = client_cert + + if key is not None and client_cert is None: + raise ValueError("ssl.certificate.location required when" + " configuring ssl.key.location") + + parsed = urlparse(self.base_urls[0]) + try: + userinfo = (unquote(parsed.username), unquote(parsed.password)) + except (AttributeError, TypeError): + userinfo = ("", "") + if 'basic.auth.user.info' in conf_copy: + if userinfo != ('', ''): + raise ValueError("basic.auth.user.info configured with" + " userinfo credentials in the URL." 
+ " Remove userinfo credentials from the url or" + " remove basic.auth.user.info from the" + " configuration") + + userinfo = tuple(conf_copy.pop('basic.auth.user.info', '').split(':', 1)) + + if len(userinfo) != 2: + raise ValueError("basic.auth.user.info must be in the form" + " of {username}:{password}") + + self.auth = userinfo if userinfo != ('', '') else None + + # The following adds support for proxy config + # If specified: it uses the specified proxy details when making requests + self.proxy = None + proxy = conf_copy.pop('proxy', None) + if proxy is not None: + self.proxy = proxy + + self.timeout = None + timeout = conf_copy.pop('timeout', None) + if timeout is not None: + self.timeout = timeout + + self.cache_capacity = 1000 + cache_capacity = conf_copy.pop('cache.capacity', None) + if cache_capacity is not None: + if not isinstance(cache_capacity, (int, float)): + raise TypeError("cache.capacity must be a number, not " + str(type(cache_capacity))) + self.cache_capacity = cache_capacity + + self.cache_latest_ttl_sec = None + cache_latest_ttl_sec = conf_copy.pop('cache.latest.ttl.sec', None) + if cache_latest_ttl_sec is not None: + if not isinstance(cache_latest_ttl_sec, (int, float)): + raise TypeError("cache.latest.ttl.sec must be a number, not " + str(type(cache_latest_ttl_sec))) + self.cache_latest_ttl_sec = cache_latest_ttl_sec + + self.max_retries = 3 + max_retries = conf_copy.pop('max.retries', None) + if max_retries is not None: + if not isinstance(max_retries, (int, float)): + raise TypeError("max.retries must be a number, not " + str(type(max_retries))) + self.max_retries = max_retries + + self.retries_wait_ms = 1000 + retries_wait_ms = conf_copy.pop('retries.wait.ms', None) + if retries_wait_ms is not None: + if not isinstance(retries_wait_ms, (int, float)): + raise TypeError("retries.wait.ms must be a number, not " + + str(type(retries_wait_ms))) + self.retries_wait_ms = retries_wait_ms + + self.retries_max_wait_ms = 20000 + retries_max_wait_ms = conf_copy.pop('retries.max.wait.ms', None) + if retries_max_wait_ms is not None: + if not isinstance(retries_max_wait_ms, (int, float)): + raise TypeError("retries.max.wait.ms must be a number, not " + + str(type(retries_max_wait_ms))) + self.retries_max_wait_ms = retries_max_wait_ms + + self.bearer_field_provider = None + logical_cluster = None + identity_pool = None + self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None) + if self.bearer_auth_credentials_source is not None: + self.auth = None + + if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}: + headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id'] + missing_headers = [header for header in headers if header not in conf_copy] + if missing_headers: + raise ValueError("Missing required bearer configuration properties: {}" + .format(", ".join(missing_headers))) + + logical_cluster = conf_copy.pop('bearer.auth.logical.cluster') + if not isinstance(logical_cluster, str): + raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster))) + + identity_pool = conf_copy.pop('bearer.auth.identity.pool.id') + if not isinstance(identity_pool, str): + raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) + + if self.bearer_auth_credentials_source == 'OAUTHBEARER': + properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope', + 'bearer.auth.issuer.endpoint.url'] + missing_properties = [prop for prop in properties_list if 
prop not in conf_copy] + if missing_properties: + raise ValueError("Missing required OAuth configuration properties: {}". + format(", ".join(missing_properties))) + + self.client_id = conf_copy.pop('bearer.auth.client.id') + if not isinstance(self.client_id, string_type): + raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id))) + + self.client_secret = conf_copy.pop('bearer.auth.client.secret') + if not isinstance(self.client_secret, string_type): + raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret))) + + self.scope = conf_copy.pop('bearer.auth.scope') + if not isinstance(self.scope, string_type): + raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope))) + + self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url') + if not isinstance(self.token_endpoint, string_type): + raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " + + str(type(self.token_endpoint))) + + self.bearer_field_provider = _AsyncOAuthClient(self.client_id, self.client_secret, self.scope, + self.token_endpoint, logical_cluster, identity_pool, + self.max_retries, self.retries_wait_ms, + self.retries_max_wait_ms) + elif self.bearer_auth_credentials_source == 'STATIC_TOKEN': + if 'bearer.auth.token' not in conf_copy: + raise ValueError("Missing bearer.auth.token") + static_token = conf_copy.pop('bearer.auth.token') + self.bearer_field_provider = _AsyncStaticFieldProvider(static_token, logical_cluster, identity_pool) + if not isinstance(static_token, string_type): + raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) + elif self.bearer_auth_credentials_source == 'CUSTOM': + custom_bearer_properties = ['bearer.auth.custom.provider.function', + 'bearer.auth.custom.provider.config'] + missing_custom_properties = [prop for prop in custom_bearer_properties if prop not in conf_copy] + if missing_custom_properties: + raise ValueError("Missing required custom OAuth configuration properties: {}". + format(", ".join(missing_custom_properties))) + + custom_function = conf_copy.pop('bearer.auth.custom.provider.function') + if not callable(custom_function): + raise TypeError("bearer.auth.custom.provider.function must be a callable, not " + + str(type(custom_function))) + + custom_config = conf_copy.pop('bearer.auth.custom.provider.config') + if not isinstance(custom_config, dict): + raise TypeError("bearer.auth.custom.provider.config must be a dict, not " + + str(type(custom_config))) + + self.bearer_field_provider = _AsyncCustomOAuthClient(custom_function, custom_config) + else: + raise ValueError('Unrecognized bearer.auth.credentials.source') + + # Any leftover keys are unknown to _RestClient + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + def get(self, url: str, query: Optional[dict] = None) -> Any: + raise NotImplementedError() + + def post(self, url: str, body: Optional[dict], **kwargs) -> Any: + raise NotImplementedError() + + def delete(self, url: str) -> Any: + raise NotImplementedError() + + def put(self, url: str, body: Optional[dict] = None) -> Any: + raise NotImplementedError() + + +class _AsyncRestClient(_AsyncBaseRestClient): + """ + HTTP client for Confluent Schema Registry. + + See SchemaRegistryClient for configuration details. 
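+
+    A minimal configuration (illustrative; substitute your own registry URL)::
+
+        {'url': 'http://localhost:8081',
+         'basic.auth.user.info': 'user:secret'}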
+ + Args: + conf (dict): Dictionary containing _RestClient configuration + """ + + def __init__(self, conf: dict): + super().__init__(conf) + + self.session = httpx.AsyncClient( + verify=self.verify, + cert=self.cert, + auth=self.auth, + proxy=self.proxy, + timeout=self.timeout + ) + + async def handle_bearer_auth(self, headers: dict) -> None: + bearer_fields = await self.bearer_field_provider.get_bearer_fields() + required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster'] + + missing_fields = [] + for field in required_fields: + if field not in bearer_fields: + missing_fields.append(field) + + if missing_fields: + raise ValueError("Missing required bearer auth fields, needs to be set in config or custom function: {}" + .format(", ".join(missing_fields))) + + headers["Authorization"] = "Bearer {}".format(bearer_fields['bearer.auth.token']) + headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id'] + headers['target-sr-cluster'] = bearer_fields['bearer.auth.logical.cluster'] + + async def get(self, url: str, query: Optional[dict] = None) -> Any: + return await self.send_request(url, method='GET', query=query) + + async def post(self, url: str, body: Optional[dict], **kwargs) -> Any: + return await self.send_request(url, method='POST', body=body) + + async def delete(self, url: str) -> Any: + return await self.send_request(url, method='DELETE') + + async def put(self, url: str, body: Optional[dict] = None) -> Any: + return await self.send_request(url, method='PUT', body=body) + + async def send_request( + self, url: str, method: str, body: Optional[dict] = None, + query: Optional[dict] = None + ) -> Any: + """ + Sends HTTP request to the SchemaRegistry, trying each base URL in turn. + + All unsuccessful attempts will raise a SchemaRegistryError with the + response contents. In most cases this will be accompanied by a + Schema Registry supplied error code. + + In the event the response is malformed an error_code of -1 will be used. + + Args: + url (str): Request path + + method (str): HTTP method + + body (str): Request content + + query (dict): Query params to attach to the URL + + Returns: + dict: Schema Registry response content. 
+ """ + + headers = {'Accept': "application/vnd.schemaregistry.v1+json," + " application/vnd.schemaregistry+json," + " application/json"} + + if body is not None: + body = json.dumps(body) + headers = {'Content-Length': str(len(body)), + 'Content-Type': "application/vnd.schemaregistry.v1+json"} + + if self.bearer_auth_credentials_source: + await self.handle_bearer_auth(headers) + + response = None + for i, base_url in enumerate(self.base_urls): + try: + response = await self.send_http_request( + base_url, url, method, headers, body, query) + + if is_success(response.status_code): + return response.json() + + if not is_retriable(response.status_code) or i == len(self.base_urls) - 1: + break + except Exception as e: + if i == len(self.base_urls) - 1: + # Raise the exception since we have no more urls to try + raise e + + try: + raise SchemaRegistryError(response.status_code, + response.json().get('error_code'), + response.json().get('message')) + # Schema Registry may return malformed output when it hits unexpected errors + except (ValueError, KeyError, AttributeError): + raise SchemaRegistryError(response.status_code, + -1, + "Unknown Schema Registry Error: " + + str(response.content)) + + async def send_http_request( + self, base_url: str, url: str, method: str, headers: Optional[dict], + body: Optional[str] = None, query: Optional[dict] = None + ) -> Response: + """ + Sends HTTP request to the SchemaRegistry. + + All unsuccessful attempts will raise a SchemaRegistryError with the + response contents. In most cases this will be accompanied by a + Schema Registry supplied error code. + + In the event the response is malformed an error_code of -1 will be used. + + Args: + base_url (str): Schema Registry base URL + + url (str): Request path + + method (str): HTTP method + + headers (dict): Headers + + body (str): Request content + + query (dict): Query params to attach to the URL + + Returns: + Response: Schema Registry response content. + """ + response = None + for i in range(self.max_retries + 1): + response = await self.session.request( + method, url="/".join([base_url, url]), + headers=headers, content=body, params=query) + + if is_success(response.status_code): + return response + + if not is_retriable(response.status_code) or i >= self.max_retries: + return response + + await asyncio.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000) + return response + + +class AsyncSchemaRegistryClient(object): + """ + A Confluent Schema Registry client. + + Configuration properties (* indicates a required field): + + +------------------------------+------+-------------------------------------------------+ + | Property name | type | Description | + +==============================+======+=================================================+ + | ``url`` * | str | Comma-separated list of Schema Registry URLs. | + +------------------------------+------+-------------------------------------------------+ + | | | Path to CA certificate file used | + | ``ssl.ca.location`` | str | to verify the Schema Registry's | + | | | private key. | + +------------------------------+------+-------------------------------------------------+ + | | | Path to client's private key | + | | | (PEM) used for authentication. | + | ``ssl.key.location`` | str | | + | | | ``ssl.certificate.location`` must also be set. | + +------------------------------+------+-------------------------------------------------+ + | | | Path to client's public key (PEM) used for | + | | | authentication. 
| + | ``ssl.certificate.location`` | str | | + | | | May be set without ssl.key.location if the | + | | | private key is stored within the PEM as well. | + +------------------------------+------+-------------------------------------------------+ + | | | Client HTTP credentials in the form of | + | | | ``username:password``. | + | ``basic.auth.user.info`` | str | | + | | | By default userinfo is extracted from | + | | | the URL if present. | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``proxy`` | str | Proxy such as http://localhost:8030. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``timeout`` | int | Request timeout. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``cache.capacity`` | int | Cache capacity. Defaults to 1000. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``cache.latest.ttl.sec`` | int | TTL in seconds for caching the latest schema. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``max.retries`` | int | Maximum retries for a request. Defaults to 2. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | Maximum time to wait for the first retry. | + | | | When jitter is applied, the actual wait may | + | ``retries.wait.ms`` | int | be less. | + | | | | + | | | Defaults to 1000. | + +------------------------------+------+-------------------------------------------------+ + + Args: + conf (dict): Schema Registry client configuration. + + See Also: + `Confluent Schema Registry documentation `_ + """ # noqa: E501 + + def __init__(self, conf: dict): + self._conf = conf + self._rest_client = _AsyncRestClient(conf) + self._cache = _SchemaCache() + cache_capacity = self._rest_client.cache_capacity + cache_ttl = self._rest_client.cache_latest_ttl_sec + if cache_ttl is not None: + self._latest_version_cache = TTLCache(cache_capacity, cache_ttl) + self._latest_with_metadata_cache = TTLCache(cache_capacity, cache_ttl) + else: + self._latest_version_cache = LRUCache(cache_capacity) + self._latest_with_metadata_cache = LRUCache(cache_capacity) + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + if self._rest_client is not None: + await self._rest_client.session.aclose() + + def config(self): + return self._conf + + async def register_schema( + self, subject_name: str, schema: 'Schema', + normalize_schemas: bool = False + ) -> int: + """ + Registers a schema under ``subject_name``. + + Args: + subject_name (str): subject to register a schema under + schema (Schema): Schema instance to register + normalize_schemas (bool): Normalize schema before registering + + Returns: + int: Schema id + + Raises: + SchemaRegistryError: if Schema violates this subject's + Compatibility policy or is otherwise invalid. + + See Also: + `POST Subject API Reference `_ + """ # noqa: E501 + + registered_schema = await self.register_schema_full_response(subject_name, schema, normalize_schemas) + return registered_schema.schema_id + + async def register_schema_full_response( + self, subject_name: str, schema: 'Schema', + normalize_schemas: bool = False + ) -> 'RegisteredSchema': + """ + Registers a schema under ``subject_name``. 
+ + Args: + subject_name (str): subject to register a schema under + schema (Schema): Schema instance to register + normalize_schemas (bool): Normalize schema before registering + + Returns: + int: Schema id + + Raises: + SchemaRegistryError: if Schema violates this subject's + Compatibility policy or is otherwise invalid. + + See Also: + `POST Subject API Reference `_ + """ # noqa: E501 + + schema_id = self._cache.get_id_by_schema(subject_name, schema) + if schema_id is not None: + return RegisteredSchema(schema_id, schema, subject_name, None) + + request = schema.to_dict() + + response = await self._rest_client.post( + 'subjects/{}/versions?normalize={}'.format(_urlencode(subject_name), normalize_schemas), + body=request) + + registered_schema = RegisteredSchema.from_dict(response) + + # The registered schema may not be fully populated + self._cache.set_schema(subject_name, registered_schema.schema_id, schema) + + return registered_schema + + async def get_schema( + self, schema_id: int, subject_name: Optional[str] = None, fmt: Optional[str] = None + ) -> 'Schema': + """ + Fetches the schema associated with ``schema_id`` from the + Schema Registry. The result is cached so subsequent attempts will not + require an additional round-trip to the Schema Registry. + + Args: + schema_id (int): Schema id + subject_name (str): Subject name the schema is registered under + fmt (str): Format of the schema + + Returns: + Schema: Schema instance identified by the ``schema_id`` + + Raises: + SchemaRegistryError: If schema can't be found. + + See Also: + `GET Schema API Reference `_ + """ # noqa: E501 + + schema = self._cache.get_schema_by_id(subject_name, schema_id) + if schema is not None: + return schema + + query = {'subject': subject_name} if subject_name is not None else None + if fmt is not None: + if query is not None: + query['format'] = fmt + else: + query = {'format': fmt} + response = await self._rest_client.get('schemas/ids/{}'.format(schema_id), query) + + schema = Schema.from_dict(response) + + self._cache.set_schema(subject_name, schema_id, schema) + + return schema + + async def lookup_schema( + self, subject_name: str, schema: 'Schema', + normalize_schemas: bool = False, deleted: bool = False + ) -> 'RegisteredSchema': + """ + Returns ``schema`` registration information for ``subject``. + + Args: + subject_name (str): Subject name the schema is registered under + schema (Schema): Schema instance. + normalize_schemas (bool): Normalize schema before registering + deleted (bool): Whether to include deleted schemas. + + Returns: + RegisteredSchema: Subject registration information for this schema. 
+ + Raises: + SchemaRegistryError: If schema or subject can't be found + + See Also: + `POST Subject API Reference `_ + """ # noqa: E501 + + registered_schema = self._cache.get_registered_by_subject_schema(subject_name, schema) + if registered_schema is not None: + return registered_schema + + request = schema.to_dict() + + response = await self._rest_client.post('subjects/{}?normalize={}&deleted={}' + .format(_urlencode(subject_name), normalize_schemas, deleted), + body=request) + + result = RegisteredSchema.from_dict(response) + + # Ensure the schema matches the input + registered_schema = RegisteredSchema( + schema_id=result.schema_id, + subject=result.subject, + version=result.version, + schema=schema, + ) + + self._cache.set_registered_schema(schema, registered_schema) + + return registered_schema + + async def get_subjects(self) -> List[str]: + """ + List all subjects registered with the Schema Registry + + Returns: + list(str): Registered subject names + + Raises: + SchemaRegistryError: if subjects can't be found + + See Also: + `GET subjects API Reference `_ + """ # noqa: E501 + + return await self._rest_client.get('subjects') + + async def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int]: + """ + Deletes the specified subject and its associated compatibility level if + registered. It is recommended to use this API only when a topic needs + to be recycled or in development environments. + + Args: + subject_name (str): subject name + permanent (bool): True for a hard delete, False (default) for a soft delete + + Returns: + list(int): Versions deleted under this subject + + Raises: + SchemaRegistryError: if the request was unsuccessful. + + See Also: + `DELETE Subject API Reference `_ + """ # noqa: E501 + + if permanent: + versions = await self._rest_client.delete('subjects/{}?permanent=true' + .format(_urlencode(subject_name))) + self._cache.remove_by_subject(subject_name) + else: + versions = await self._rest_client.delete('subjects/{}' + .format(_urlencode(subject_name))) + + return versions + + async def get_latest_version( + self, subject_name: str, fmt: Optional[str] = None + ) -> 'RegisteredSchema': + """ + Retrieves latest registered version for subject + + Args: + subject_name (str): Subject name. + fmt (str): Format of the schema + + Returns: + RegisteredSchema: Registration information for this version. + + Raises: + SchemaRegistryError: if the version can't be found or is invalid. + + See Also: + `GET Subject Version API Reference `_ + """ # noqa: E501 + + registered_schema = self._latest_version_cache.get(subject_name, None) + if registered_schema is not None: + return registered_schema + + query = {'format': fmt} if fmt is not None else None + response = await self._rest_client.get('subjects/{}/versions/{}' + .format(_urlencode(subject_name), + 'latest'), query) + + registered_schema = RegisteredSchema.from_dict(response) + + self._latest_version_cache[subject_name] = registered_schema + + return registered_schema + + async def get_latest_with_metadata( + self, subject_name: str, metadata: Dict[str, str], + deleted: bool = False, fmt: Optional[str] = None + ) -> 'RegisteredSchema': + """ + Retrieves latest registered version for subject with the given metadata + + Args: + subject_name (str): Subject name. + metadata (dict): The key-value pairs for the metadata. + deleted (bool): Whether to include deleted schemas. + fmt (str): Format of the schema + + Returns: + RegisteredSchema: Registration information for this version. 
+ + Raises: + SchemaRegistryError: if the version can't be found or is invalid. + """ # noqa: E501 + + cache_key = (subject_name, frozenset(metadata.items()), deleted) + registered_schema = self._latest_with_metadata_cache.get(cache_key, None) + if registered_schema is not None: + return registered_schema + + query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} + keys = metadata.keys() + if keys: + query['key'] = [_urlencode(key) for key in keys] + query['value'] = [_urlencode(metadata[key]) for key in keys] + + response = await self._rest_client.get('subjects/{}/metadata' + .format(_urlencode(subject_name)), query) + + registered_schema = RegisteredSchema.from_dict(response) + + self._latest_with_metadata_cache[cache_key] = registered_schema + + return registered_schema + + async def get_version( + self, subject_name: str, version: int, + deleted: bool = False, fmt: Optional[str] = None + ) -> 'RegisteredSchema': + """ + Retrieves a specific schema registered under ``subject_name``. + + Args: + subject_name (str): Subject name. + version (int): version number. Defaults to latest version. + deleted (bool): Whether to include deleted schemas. + fmt (str): Format of the schema + + Returns: + RegisteredSchema: Registration information for this version. + + Raises: + SchemaRegistryError: if the version can't be found or is invalid. + + See Also: + `GET Subject Version API Reference `_ + """ # noqa: E501 + + registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) + if registered_schema is not None: + return registered_schema + + query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} + response = await self._rest_client.get('subjects/{}/versions/{}' + .format(_urlencode(subject_name), + version), query) + + registered_schema = RegisteredSchema.from_dict(response) + + self._cache.set_registered_schema(registered_schema.schema, registered_schema) + + return registered_schema + + async def get_versions(self, subject_name: str) -> List[int]: + """ + Get a list of all versions registered with this subject. + + Args: + subject_name (str): Subject name. + + Returns: + list(int): Registered versions + + Raises: + SchemaRegistryError: If subject can't be found + + See Also: + `GET Subject Versions API Reference `_ + """ # noqa: E501 + + return await self._rest_client.get('subjects/{}/versions'.format(_urlencode(subject_name))) + + async def delete_version(self, subject_name: str, version: int, permanent: bool = False) -> int: + """ + Deletes a specific version registered to ``subject_name``. + + Args: + subject_name (str) Subject name + + version (int): Version number + + permanent (bool): True for a hard delete, False (default) for a soft delete + + Returns: + int: Version number which was deleted + + Raises: + SchemaRegistryError: if the subject or version cannot be found. + + See Also: + `Delete Subject Version API Reference `_ + """ # noqa: E501 + + if permanent: + response = await self._rest_client.delete('subjects/{}/versions/{}?permanent=true' + .format(_urlencode(subject_name), + version)) + self._cache.remove_by_subject_version(subject_name, version) + else: + response = await self._rest_client.delete('subjects/{}/versions/{}' + .format(_urlencode(subject_name), + version)) + + return response + + async def set_compatibility(self, subject_name: Optional[str] = None, level: Optional[str] = None) -> str: + """ + Update global or subject level compatibility level. 
+ + Args: + level (str): Compatibility level. See API reference for a list of + valid values. + + subject_name (str, optional): Subject to update. Sets global compatibility + level policy if not set. + + Returns: + str: The newly configured compatibility level. + + Raises: + SchemaRegistryError: If the compatibility level is invalid. + + See Also: + `PUT Subject Compatibility API Reference `_ + """ # noqa: E501 + + if level is None: + raise ValueError("level must be set") + + if subject_name is None: + return await self._rest_client.put('config', + body={'compatibility': level.upper()}) + + return await self._rest_client.put('config/{}' + .format(_urlencode(subject_name)), + body={'compatibility': level.upper()}) + + async def get_compatibility(self, subject_name: Optional[str] = None) -> str: + """ + Get the current compatibility level. + + Args: + subject_name (str, optional): Subject name. Returns global policy + if left unset. + + Returns: + str: Compatibility level for the subject if set, otherwise the global compatibility level. + + Raises: + SchemaRegistryError: if the request was unsuccessful. + + See Also: + `GET Subject Compatibility API Reference `_ + """ # noqa: E501 + + if subject_name is not None: + url = 'config/{}'.format(_urlencode(subject_name)) + else: + url = 'config' + + result = await self._rest_client.get(url) + return result['compatibilityLevel'] + + async def test_compatibility( + self, subject_name: str, schema: 'Schema', + version: Union[int, str] = "latest" + ) -> bool: + """Test the compatibility of a candidate schema for a given subject and version + + Args: + subject_name (str): Subject name the schema is registered under + schema (Schema): Schema instance. + version (int or str, optional): Version number, or the string "latest". Defaults to "latest". + + Returns: + bool: True if the schema is compatible with the specified version + + Raises: + SchemaRegistryError: if the request was unsuccessful. + + See Also: + `POST Test Compatibility API Reference `_ + """ # noqa: E501 + + request = schema.to_dict() + + response = await self._rest_client.post( + 'compatibility/subjects/{}/versions/{}'.format(_urlencode(subject_name), version), body=request + ) + + return response['is_compatible'] + + async def set_config( + self, subject_name: Optional[str] = None, + config: Optional['ServerConfig'] = None + ) -> 'ServerConfig': + """ + Update global or subject config. + + Args: + config (ServerConfig): Config. See API reference for a list of + valid values. + + subject_name (str, optional): Subject to update. Sets global config + if not set. + + Returns: + str: The newly configured config. + + Raises: + SchemaRegistryError: If the config is invalid. + + See Also: + `PUT Subject Config API Reference `_ + """ # noqa: E501 + + if config is None: + raise ValueError("config must be set") + + if subject_name is None: + return await self._rest_client.put('config', + body=config.to_dict()) + + return await self._rest_client.put('config/{}' + .format(_urlencode(subject_name)), + body=config.to_dict()) + + async def get_config(self, subject_name: Optional[str] = None) -> 'ServerConfig': + """ + Get the current config. + + Args: + subject_name (str, optional): Subject name. Returns global config + if left unset. + + Returns: + ServerConfig: Config for the subject if set, otherwise the global config. + + Raises: + SchemaRegistryError: if the request was unsuccessful. 
+ + See Also: + `GET Subject Config API Reference `_ + """ # noqa: E501 + + if subject_name is not None: + url = 'config/{}'.format(_urlencode(subject_name)) + else: + url = 'config' + + result = await self._rest_client.get(url) + return ServerConfig.from_dict(result) + + def clear_latest_caches(self): + self._latest_version_cache.clear() + self._latest_with_metadata_cache.clear() + + def clear_caches(self): + self._latest_version_cache.clear() + self._latest_with_metadata_cache.clear() + self._cache.clear() + + @staticmethod + def new_client(conf: dict) -> 'AsyncSchemaRegistryClient': + from confluent_kafka.schema_registry.mock_schema_registry_client import MockSchemaRegistryClient + url = conf.get("url") + if url.startswith("mock://"): + return MockSchemaRegistryClient(conf) + return AsyncSchemaRegistryClient(conf) diff --git a/src/confluent_kafka/schema_registry/_async/serde.py b/src/confluent_kafka/schema_registry/_async/serde.py new file mode 100644 index 000000000..677d916f7 --- /dev/null +++ b/src/confluent_kafka/schema_registry/_async/serde.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2024 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +from typing import List, Optional, Set, Dict, Any + +from confluent_kafka.schema_registry import RegisteredSchema +from confluent_kafka.schema_registry.common.serde import ErrorAction, FieldTransformer, Migration, NoneAction, RuleAction, RuleConditionError, RuleContext, RuleError +from confluent_kafka.schema_registry.schema_registry_client import RuleMode, \ + Rule, RuleKind, Schema, RuleSet +from confluent_kafka.serialization import Serializer, Deserializer, \ + SerializationContext, SerializationError + +log = logging.getLogger(__name__) + +class AsyncBaseSerde(object): + __slots__ = ['_use_schema_id', '_use_latest_version', '_use_latest_with_metadata', + '_registry', '_rule_registry', '_subject_name_func', + '_field_transformer'] + + async def _get_reader_schema(self, subject: str, fmt: Optional[str] = None) -> Optional[RegisteredSchema]: + if self._use_schema_id is not None: + schema = await self._registry.get_schema(self._use_schema_id, subject, fmt) + return await self._registry.lookup_schema(subject, schema, False, True) + if self._use_latest_with_metadata is not None: + return await self._registry.get_latest_with_metadata( + subject, self._use_latest_with_metadata, True, fmt) + if self._use_latest_version: + return await self._registry.get_latest_version(subject, fmt) + return None + + def _execute_rules( + self, ser_ctx: SerializationContext, subject: str, + rule_mode: RuleMode, + source: Optional[Schema], target: Optional[Schema], + message: Any, inline_tags: Optional[Dict[str, Set[str]]], + field_transformer: Optional[FieldTransformer] + ) -> Any: + if message is None or target is None: + return message + rules: Optional[List[Rule]] = None + if rule_mode == RuleMode.UPGRADE: + if target is not None and target.rule_set is not None: + rules = 
target.rule_set.migration_rules + elif rule_mode == RuleMode.DOWNGRADE: + if source is not None and source.rule_set is not None: + rules = source.rule_set.migration_rules + rules = rules[:] if rules else [] + rules.reverse() + else: + if target is not None and target.rule_set is not None: + rules = target.rule_set.domain_rules + if rule_mode == RuleMode.READ: + # Execute read rules in reverse order for symmetry + rules = rules[:] if rules else [] + rules.reverse() + + if not rules: + return message + + for index in range(len(rules)): + rule = rules[index] + if self._is_disabled(rule): + continue + if rule.mode == RuleMode.WRITEREAD: + if rule_mode != RuleMode.READ and rule_mode != RuleMode.WRITE: + continue + elif rule.mode == RuleMode.UPDOWN: + if rule_mode != RuleMode.UPGRADE and rule_mode != RuleMode.DOWNGRADE: + continue + elif rule.mode != rule_mode: + continue + + ctx = RuleContext(ser_ctx, source, target, subject, rule_mode, rule, + index, rules, inline_tags, field_transformer) + rule_executor = self._rule_registry.get_executor(rule.type.upper()) + if rule_executor is None: + self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), message, + RuleError(f"Could not find rule executor of type {rule.type}"), + 'ERROR') + return message + try: + result = rule_executor.transform(ctx, message) + if rule.kind == RuleKind.CONDITION: + if not result: + raise RuleConditionError(rule) + elif rule.kind == RuleKind.TRANSFORM: + message = result + self._run_action( + ctx, rule_mode, rule, + self._get_on_failure(rule) if message is None else self._get_on_success(rule), + message, None, + 'ERROR' if message is None else 'NONE') + except SerializationError: + raise + except Exception as e: + self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), + message, e, 'ERROR') + return message + + def _get_on_success(self, rule: Rule) -> Optional[str]: + override = self._rule_registry.get_override(rule.type) + if override is not None and override.on_success is not None: + return override.on_success + return rule.on_success + + def _get_on_failure(self, rule: Rule) -> Optional[str]: + override = self._rule_registry.get_override(rule.type) + if override is not None and override.on_failure is not None: + return override.on_failure + return rule.on_failure + + def _is_disabled(self, rule: Rule) -> Optional[bool]: + override = self._rule_registry.get_override(rule.type) + if override is not None and override.disabled is not None: + return override.disabled + return rule.disabled + + def _run_action( + self, ctx: RuleContext, rule_mode: RuleMode, rule: Rule, + action: Optional[str], message: Any, + ex: Optional[Exception], default_action: str + ): + action_name = self._get_rule_action_name(rule, rule_mode, action) + if action_name is None: + action_name = default_action + rule_action = self._get_rule_action(ctx, action_name) + if rule_action is None: + log.error("Could not find rule action of type %s", action_name) + raise RuleError(f"Could not find rule action of type {action_name}") + try: + rule_action.run(ctx, message, ex) + except SerializationError: + raise + except Exception as e: + log.warning("Could not run post-rule action %s: %s", action_name, e) + + def _get_rule_action_name( + self, rule: Rule, rule_mode: RuleMode, action_name: Optional[str] + ) -> Optional[str]: + if action_name is None or action_name == "": + return None + if rule.mode in (RuleMode.WRITEREAD, RuleMode.UPDOWN) and ',' in action_name: + parts = action_name.split(',') + if rule_mode in (RuleMode.WRITE, 
RuleMode.UPGRADE):
+                return parts[0]
+            elif rule_mode in (RuleMode.READ, RuleMode.DOWNGRADE):
+                return parts[1]
+        return action_name
+
+    def _get_rule_action(self, ctx: RuleContext, action_name: str) -> Optional[RuleAction]:
+        if action_name == 'ERROR':
+            return ErrorAction()
+        elif action_name == 'NONE':
+            return NoneAction()
+        return self._rule_registry.get_action(action_name)
+
+
+class AsyncBaseSerializer(AsyncBaseSerde, Serializer):
+    __slots__ = ['_auto_register', '_normalize_schemas']
+
+
+class AsyncBaseDeserializer(AsyncBaseSerde, Deserializer):
+    __slots__ = []
+
+    def _has_rules(self, rule_set: RuleSet, mode: RuleMode) -> bool:
+        if rule_set is None:
+            return False
+        if mode in (RuleMode.UPGRADE, RuleMode.DOWNGRADE):
+            return any(rule.mode == mode or rule.mode == RuleMode.UPDOWN
+                       for rule in rule_set.migration_rules or [])
+        elif mode == RuleMode.UPDOWN:
+            return any(rule.mode == mode for rule in rule_set.migration_rules or [])
+        elif mode in (RuleMode.WRITE, RuleMode.READ):
+            return any(rule.mode == mode or rule.mode == RuleMode.WRITEREAD
+                       for rule in rule_set.domain_rules or [])
+        elif mode == RuleMode.WRITEREAD:
+            return any(rule.mode == mode for rule in rule_set.migration_rules or [])
+        return False
+
+    async def _get_migrations(
+        self, subject: str, source_info: Schema,
+        target: RegisteredSchema, fmt: Optional[str]
+    ) -> List[Migration]:
+        source = await self._registry.lookup_schema(subject, source_info, False, True)
+        migrations = []
+        if source.version < target.version:
+            migration_mode = RuleMode.UPGRADE
+            first = source
+            last = target
+        elif source.version > target.version:
+            migration_mode = RuleMode.DOWNGRADE
+            first = target
+            last = source
+        else:
+            return migrations
+        previous: Optional[RegisteredSchema] = None
+        versions = await self._get_schemas_between(subject, first, last, fmt)
+        for i in range(len(versions)):
+            version = versions[i]
+            if i == 0:
+                previous = version
+                continue
+            if version.schema.rule_set is not None and self._has_rules(version.schema.rule_set, migration_mode):
+                if migration_mode == RuleMode.UPGRADE:
+                    migration = Migration(migration_mode, previous, version)
+                else:
+                    migration = Migration(migration_mode, version, previous)
+                migrations.append(migration)
+            previous = version
+        if migration_mode == RuleMode.DOWNGRADE:
+            migrations.reverse()
+        return migrations
+
+    async def _get_schemas_between(
+        self, subject: str, first: RegisteredSchema,
+        last: RegisteredSchema, fmt: Optional[str] = None
+    ) -> List[RegisteredSchema]:
+        if last.version - first.version <= 1:
+            return [first, last]
+        version1 = first.version
+        version2 = last.version
+        result = [first]
+        for i in range(version1 + 1, version2):
+            result.append(await self._registry.get_version(subject, i, True, fmt))
+        result.append(last)
+        return result
+
+    def _execute_migrations(
+        self, ser_ctx: SerializationContext, subject: str,
+        migrations: List[Migration], message: Any
+    ) -> Any:
+        for migration in migrations:
+            message = self._execute_rules(ser_ctx, subject, migration.rule_mode,
+                                          migration.source.schema, migration.target.schema,
+                                          message, None, None)
+        return message
diff --git a/tests/integration/cluster_fixture.py b/tests/integration/cluster_fixture.py
index efaa55f70..0d441ca1d 100644
--- a/tests/integration/cluster_fixture.py
+++ b/tests/integration/cluster_fixture.py
@@ -24,9 +24,12 @@
 from confluent_kafka import Producer, SerializingProducer
 from confluent_kafka.admin import AdminClient, NewTopic
 from confluent_kafka.schema_registry.schema_registry_client import 
SchemaRegistryClient +from confluent_kafka.schema_registry._async.schema_registry_client import AsyncSchemaRegistryClient from tests.common import TestConsumer from tests.common.schema_registry import TestDeserializingConsumer +from tests.common._async.consumer import TestAsyncDeserializingConsumer +from tests.common._async.producer import TestAsyncSerializingProducer class KafkaClusterFixture(object): @@ -89,6 +92,32 @@ def producer(self, conf=None, key_serializer=None, value_serializer=None): return SerializingProducer(client_conf) + def async_producer(self, conf=None, key_serializer=None, value_serializer=None): + """ + Returns a producer bound to this cluster. + + Args: + conf (dict): Producer configuration overrides + + key_serializer (Serializer): serializer to apply to message key + + value_serializer (Deserializer): serializer to apply to + message value + + Returns: + Producer: A new SerializingProducer instance + + """ + client_conf = self.client_conf(conf) + + if key_serializer is not None: + client_conf['key.serializer'] = key_serializer + + if value_serializer is not None: + client_conf['value.serializer'] = value_serializer + + return TestAsyncSerializingProducer(client_conf) + def cimpl_consumer(self, conf=None): """ Returns a consumer bound to this cluster. @@ -143,6 +172,39 @@ def consumer(self, conf=None, key_deserializer=None, value_deserializer=None): return TestDeserializingConsumer(consumer_conf) + def async_consumer(self, conf=None, key_deserializer=None, value_deserializer=None): + """ + Returns a consumer bound to this cluster. + + Args: + conf (dict): Consumer config overrides + + key_deserializer (Deserializer): deserializer to apply to + message key + + value_deserializer (Deserializer): deserializer to apply to + message value + + Returns: + Consumer: A new DeserializingConsumer instance + + """ + consumer_conf = self.client_conf({ + 'group.id': str(uuid1()), + 'auto.offset.reset': 'earliest' + }) + + if conf is not None: + consumer_conf.update(conf) + + if key_deserializer is not None: + consumer_conf['key.deserializer'] = key_deserializer + + if value_deserializer is not None: + consumer_conf['value.deserializer'] = value_deserializer + + return TestAsyncDeserializingConsumer(consumer_conf) + def admin(self, conf=None): if conf: # When conf is passed create a new AdminClient @@ -286,6 +348,15 @@ def schema_registry(self, conf=None): sr_conf.update(conf) return SchemaRegistryClient(sr_conf) + def async_schema_registry(self, conf=None): + if not hasattr(self._cluster, 'sr'): + return None + + sr_conf = {'url': self._cluster.sr.get('url')} + if conf is not None: + sr_conf.update(conf) + return AsyncSchemaRegistryClient(sr_conf) + def client_conf(self, conf=None): """ Default client configuration diff --git a/tests/integration/schema_registry/_async/__init__.py b/tests/integration/schema_registry/_async/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/schema_registry/_async/test_api_client.py b/tests/integration/schema_registry/_async/test_api_client.py new file mode 100644 index 000000000..244e1c4b1 --- /dev/null +++ b/tests/integration/schema_registry/_async/test_api_client.py @@ -0,0 +1,491 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from uuid import uuid1 + +import pytest + +from confluent_kafka.schema_registry import Schema +from confluent_kafka.schema_registry.error import SchemaRegistryError +from tests.integration.conftest import kafka_cluster_fixture + + +@pytest.fixture(scope="module") +def kafka_cluster_cp_7_0_1(): + """ + Returns a Trivup cluster with CP version 7.0.1. + SR version 7.0.1 is the last returning 500 instead of 422 + for the invalid schema passed to test_api_get_register_schema_invalid + """ + for fixture in kafka_cluster_fixture( + brokers_env="BROKERS_7_0_1", + sr_url_env="SR_URL_7_0_1", + trivup_cluster_conf={'cp_version': '7.0.1'} + ): + yield fixture + + +def _subject_name(prefix): + return prefix + "-" + str(uuid1()) + + +async def test_api_register_schema(kafka_cluster, load_file): + """ + Registers a schema, verifies the registration + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + avsc = 'basic_schema.avsc' + subject = _subject_name(avsc) + schema = Schema(load_file(avsc), schema_type='AVRO') + + schema_id = await sr.register_schema(subject, schema) + registered_schema = await sr.lookup_schema(subject, schema) + + assert registered_schema.schema_id == schema_id + assert registered_schema.subject == subject + assert schema.schema_str, registered_schema.schema.schema_str + + +async def test_api_register_normalized_schema(kafka_cluster, load_file): + """ + Registers a schema, verifies the registration + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + avsc = 'basic_schema.avsc' + subject = _subject_name(avsc) + schema = Schema(load_file(avsc), schema_type='AVRO') + + schema_id = await sr.register_schema(subject, schema, True) + registered_schema = await sr.lookup_schema(subject, schema, True) + + assert registered_schema.schema_id == schema_id + assert registered_schema.subject == subject + assert schema.schema_str, registered_schema.schema.schema_str + + +async def test_api_register_schema_incompatible(kafka_cluster, load_file): + """ + Attempts to register an incompatible Schema verifies the error. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + schema1 = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + schema2 = Schema(load_file('adv_schema.avsc'), schema_type='AVRO') + subject = _subject_name('test_register_incompatible') + + await sr.register_schema(subject, schema1) + + with pytest.raises(SchemaRegistryError, match="Schema being registered is" + " incompatible with an" + " earlier schema") as e: + # The default Schema Registry compatibility type is BACKWARD. + # this allows 1) fields to be deleted, 2) optional fields to + # be added. schema2 adds non-optional fields to schema1, so + # registering schema2 after schema1 should fail. 
+ await sr.register_schema(subject, schema2) + assert e.value.http_status_code == 409 # conflict + assert e.value.error_code == 409 + + +async def test_api_register_schema_invalid(kafka_cluster, load_file): + """ + Attempts to register an invalid schema, validates the error. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + schema = Schema(load_file('invalid_schema.avsc'), schema_type='AVRO') + subject = _subject_name('test_invalid_schema') + + with pytest.raises(SchemaRegistryError) as e: + await sr.register_schema(subject, schema) + assert e.value.http_status_code == 422 + assert e.value.error_code == 42201 + + +async def test_api_get_schema(kafka_cluster, load_file): + """ + Registers a schema then retrieves it using the schema id returned by the + call to register the Schema. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + subject = _subject_name('get_schema') + + schema_id = await sr.register_schema(subject, schema) + registration = await sr.lookup_schema(subject, schema) + + assert registration.schema_id == schema_id + assert registration.subject == subject + assert schema.schema_str, registration.schema.schema_str + + +async def test_api_get_schema_not_found(kafka_cluster, load_file): + """ + Attempts to fetch an unknown schema by id, validates the error. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + + with pytest.raises(SchemaRegistryError, match="Schema .*not found.*") as e: + await sr.get_schema(999999) + + assert e.value.http_status_code == 404 + assert e.value.error_code == 40403 + + +async def test_api_get_registration_subject_not_found(kafka_cluster, load_file): + """ + Attempts to obtain information about a schema's subject registration for + an unknown subject. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + + subject = _subject_name("registration_subject_not_found") + + with pytest.raises(SchemaRegistryError, match="Subject .*not found.*") as e: + await sr.lookup_schema(subject, schema) + assert e.value.http_status_code == 404 + assert e.value.error_code == 40401 + + +@pytest.mark.parametrize("kafka_cluster_name, http_status_code, error_code", [ + ["kafka_cluster_cp_7_0_1", 500, 500], + ["kafka_cluster", 422, 42201], +]) +async def test_api_get_register_schema_invalid( + kafka_cluster_name, + http_status_code, + error_code, + load_file, + request): + """ + Attempts to obtain registration information with an invalid schema + with different CP versions. 
+ + Args: + kafka_cluster_name (str): name of the Kafka Cluster fixture to use + http_status_code (int): HTTP status return code expected in this version + error_code (int): error code expected in this version + load_file (callable(str)): Schema fixture constructor + request (FixtureRequest): PyTest object giving access to the test context + """ + kafka_cluster = request.getfixturevalue(kafka_cluster_name) + sr = kafka_cluster.async_schema_registry() + subject = _subject_name("registration_invalid_schema") + schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + + # register valid schema so we don't hit subject not found exception + await sr.register_schema(subject, schema) + schema2 = Schema(load_file('invalid_schema.avsc'), schema_type='AVRO') + + with pytest.raises(SchemaRegistryError, match="Invalid schema") as e: + await sr.lookup_schema(subject, schema2) + + assert e.value.http_status_code == http_status_code + assert e.value.error_code == error_code + + +async def test_api_get_subjects(kafka_cluster, load_file): + """ + Populates KafkaClusterFixture SR instance with a fixed number of subjects + then verifies the response includes them all. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + + avscs = ['basic_schema.avsc', 'primitive_string.avsc', + 'primitive_bool.avsc', 'primitive_float.avsc'] + + subjects = [] + for avsc in avscs: + schema = Schema(load_file(avsc), schema_type='AVRO') + subject = _subject_name(avsc) + subjects.append(subject) + + await sr.register_schema(subject, schema) + + registered = await sr.get_subjects() + + assert all([s in registered for s in subjects]) + + +async def test_api_get_subject_versions(kafka_cluster, load_file): + """ + Registers a Schema with a subject, lists the versions associated with that + subject and ensures the versions and their schemas match what was + registered. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor. + + """ + sr = kafka_cluster.async_schema_registry() + + subject = _subject_name("list-version-test") + await sr.set_compatibility(level="NONE") + + avscs = ['basic_schema.avsc', 'primitive_string.avsc', + 'primitive_bool.avsc', 'primitive_float.avsc'] + + schemas = [] + for avsc in avscs: + schema = Schema(load_file(avsc), schema_type='AVRO') + schemas.append(schema) + await sr.register_schema(subject, schema) + + versions = await sr.get_versions(subject) + assert len(versions) == len(avscs) + for schema in schemas: + registered_schema = await sr.lookup_schema(subject, schema) + assert registered_schema.subject == subject + assert registered_schema.version in versions + + # revert global compatibility level back to the default. + await sr.set_compatibility(level="BACKWARD") + + +async def test_api_delete_subject(kafka_cluster, load_file): + """ + Registers a Schema under a specific subject then deletes it. 
+
+    Args:
+        kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture
+        load_file (callable(str)): Schema fixture constructor
+
+    """
+    sr = kafka_cluster.async_schema_registry()
+
+    schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO')
+    subject = _subject_name("test-delete")
+
+    await sr.register_schema(subject, schema)
+    assert subject in await sr.get_subjects()
+
+    await sr.delete_subject(subject)
+    assert subject not in await sr.get_subjects()
+
+
+async def test_api_delete_subject_not_found(kafka_cluster):
+    sr = kafka_cluster.async_schema_registry()
+
+    subject = _subject_name("test-delete_invalid_subject")
+
+    with pytest.raises(SchemaRegistryError, match="Subject .*not found.*") as e:
+        await sr.delete_subject(subject)
+    assert e.value.http_status_code == 404
+    assert e.value.error_code == 40401
+
+
+async def test_api_get_subject_version(kafka_cluster, load_file):
+    """
+    Registers a schema, fetches that schema by its subject version id.
+
+    Args:
+        kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture
+        load_file (callable(str)): Schema fixture constructor
+
+    """
+    sr = kafka_cluster.async_schema_registry()
+
+    schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO')
+    subject = _subject_name('test-get_subject')
+
+    await sr.register_schema(subject, schema)
+    registered_schema = await sr.lookup_schema(subject, schema)
+    registered_schema2 = await sr.get_version(subject, registered_schema.version)
+
+    assert registered_schema2.schema_id == registered_schema.schema_id
+    assert registered_schema2.schema.schema_str == registered_schema.schema.schema_str
+    assert registered_schema2.version == registered_schema.version
+
+
+async def test_api_get_subject_version_no_version(kafka_cluster, load_file):
+    sr = kafka_cluster.async_schema_registry()
+
+    # ensures subject exists and has a single version
+    schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO')
+    subject = _subject_name('test-get_subject')
+    await sr.register_schema(subject, schema)
+
+    with pytest.raises(SchemaRegistryError, match="Version .*not found") as e:
+        await sr.get_version(subject, version=3)
+    assert e.value.http_status_code == 404
+    assert e.value.error_code == 40402
+
+
+async def test_api_get_subject_version_invalid(kafka_cluster, load_file):
+    sr = kafka_cluster.async_schema_registry()
+
+    # ensures subject exists and has a single version
+    schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO')
+    subject = _subject_name('test-get_subject')
+    await sr.register_schema(subject, schema)
+
+    with pytest.raises(SchemaRegistryError,
+                       match="The specified version .*is not"
+                             " a valid version id.*") as e:
+        await sr.get_version(subject, version='a')
+    assert e.value.http_status_code == 422
+    assert e.value.error_code == 42202
+
+
+async def test_api_post_subject_registration(kafka_cluster, load_file):
+    """
+    Registers a schema, fetches that schema by its subject version id.
+ + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + + schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + subject = _subject_name('test_registration') + + schema_id = await sr.register_schema(subject, schema) + registered_schema = await sr.lookup_schema(subject, schema) + + assert registered_schema.schema_id == schema_id + assert registered_schema.subject == subject + + +async def test_api_delete_subject_version(kafka_cluster, load_file): + """ + Registers a Schema under a specific subject then deletes it. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + + schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + subject = str(uuid1()) + + await sr.register_schema(subject, schema) + await sr.delete_version(subject, 1) + + assert subject not in await sr.get_subjects() + + +async def test_api_subject_config_update(kafka_cluster, load_file): + """ + Updates a subjects compatibility policy then ensures the same policy + is returned when queried. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + load_file (callable(str)): Schema fixture constructor + + """ + sr = kafka_cluster.async_schema_registry() + + schema = Schema(load_file('basic_schema.avsc'), schema_type='AVRO') + subject = str(uuid1()) + + await sr.register_schema(subject, schema) + await sr.set_compatibility(subject_name=subject, + level="FULL_TRANSITIVE") + + assert await sr.get_compatibility(subject_name=subject) == "FULL_TRANSITIVE" + + +async def test_api_config_invalid(kafka_cluster): + """ + Sets an invalid compatibility level, validates the exception. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + """ + sr = kafka_cluster.async_schema_registry() + + with pytest.raises(SchemaRegistryError, match="Invalid compatibility" + " level") as e: + await sr.set_compatibility(level="INVALID") + e.value.http_status_code = 422 + e.value.error_code = 42203 + + +async def test_api_config_update(kafka_cluster): + """ + Updates the global compatibility policy then ensures the same policy + is returned when queried. + + Args: + kafka_cluster (KafkaClusterFixture): Kafka Cluster fixture + """ + sr = kafka_cluster.async_schema_registry() + + for level in ["BACKWARD", "BACKWARD_TRANSITIVE", "FORWARD", "FORWARD_TRANSITIVE"]: + await sr.set_compatibility(level=level) + assert await sr.get_compatibility() == level + + # revert global compatibility level back to the default. 
+ await sr.set_compatibility(level="BACKWARD") + + +async def test_api_register_logical_schema(kafka_cluster, load_file): + sr = kafka_cluster.async_schema_registry() + + schema = Schema(load_file('logical_date.avsc'), schema_type='AVRO') + subject = _subject_name('test_logical_registration') + + schema_id = await sr.register_schema(subject, schema) + registered_schema = await sr.lookup_schema(subject, schema) + + assert registered_schema.schema_id == schema_id + assert registered_schema.subject == subject diff --git a/tests/integration/schema_registry/_async/test_avro_serializers.py b/tests/integration/schema_registry/_async/test_avro_serializers.py new file mode 100644 index 000000000..966bbc14d --- /dev/null +++ b/tests/integration/schema_registry/_async/test_avro_serializers.py @@ -0,0 +1,350 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest + +from confluent_kafka import TopicPartition +from confluent_kafka.serialization import (MessageField, + SerializationContext) +from confluent_kafka.schema_registry.avro import (AsyncAvroSerializer, + AsyncAvroDeserializer) +from confluent_kafka.schema_registry import Schema, SchemaReference + + +class User(object): + schema_str = """ + { + "namespace": "confluent.io.examples.serialization.avro", + "name": "User", + "type": "record", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "favorite_number", "type": "int"}, + {"name": "favorite_color", "type": "string"} + ] + } + """ + + def __init__(self, name, favorite_number, favorite_color): + self.name = name + self.favorite_number = favorite_number + self.favorite_color = favorite_color + + def __eq__(self, other): + return all([ + self.name == other.name, + self.favorite_number == other.favorite_number, + self.favorite_color == other.favorite_color]) + + +class AwardProperties(object): + schema_str = """ + { + "namespace": "confluent.io.examples.serialization.avro", + "name": "AwardProperties", + "type": "record", + "fields": [ + {"name": "year", "type": "int"}, + {"name": "points", "type": "int"} + ] + } + """ + + def __init__(self, points, year): + self.points = points + self.year = year + + def __eq__(self, other): + return all([ + self.points == other.points, + self.year == other.year + ]) + + +class Award(object): + schema_str = """ + { + "namespace": "confluent.io.examples.serialization.avro", + "name": "Award", + "type": "record", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "properties", "type": "AwardProperties"} + ] + } + """ + + def __init__(self, name, properties): + self.name = name + self.properties = properties + + def __eq__(self, other): + return all([ + self.name == other.name, + self.properties == other.properties + ]) + + +class AwardedUser(object): + schema_str = """ + { + "namespace": "confluent.io.examples.serialization.avro", + "name": "AwardedUser", + "type": "record", + "fields": [ + {"name": "award", "type": "Award"}, + {"name": "user", 
"type": "User"} + ] + } + """ + + def __init__(self, award, user): + self.award = award + self.user = user + + def __eq__(self, other): + return all([ + self.award == other.award, + self.user == other.user + ]) + + +async def _register_avro_schemas_and_build_awarded_user_schema(kafka_cluster): + sr = kafka_cluster.async_schema_registry() + + user = User('Bowie', 47, 'purple') + award_properties = AwardProperties(10, 2023) + award = Award("Best In Show", award_properties) + awarded_user = AwardedUser(award, user) + + user_schema_ref = SchemaReference("confluent.io.examples.serialization.avro.User", "user", 1) + award_properties_schema_ref = SchemaReference("confluent.io.examples.serialization.avro.AwardProperties", + "award_properties", 1) + award_schema_ref = SchemaReference("confluent.io.examples.serialization.avro.Award", "award", 1) + + await sr.register_schema("user", Schema(User.schema_str, 'AVRO')) + await sr.register_schema("award_properties", Schema(AwardProperties.schema_str, 'AVRO')) + await sr.register_schema("award", Schema(Award.schema_str, 'AVRO', [award_properties_schema_ref])) + + references = [user_schema_ref, award_schema_ref] + schema = Schema(AwardedUser.schema_str, 'AVRO', references) + return awarded_user, schema + + +async def _references_test_common(kafka_cluster, awarded_user, serializer_schema, deserializer_schema): + """ + Common (both reader and writer) avro schema reference test. + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + """ + topic = kafka_cluster.create_topic_and_wait_propogation("reference-avro") + sr = kafka_cluster.async_schema_registry() + + value_serializer = await AsyncAvroSerializer(sr, serializer_schema, + lambda user, ctx: + dict(award=dict(name=user.award.name, + properties=dict(year=user.award.properties.year, + points=user.award.properties.points)), + user=dict(name=user.user.name, + favorite_number=user.user.favorite_number, + favorite_color=user.user.favorite_color))) + + value_deserializer = \ + await AsyncAvroDeserializer(sr, deserializer_schema, + lambda user, ctx: + AwardedUser(award=Award(name=user.get('award').get('name'), + properties=AwardProperties( + year=user.get('award').get('properties').get( + 'year'), + points=user.get('award').get('properties').get( + 'points'))), + user=User(name=user.get('user').get('name'), + favorite_number=user.get('user').get('favorite_number'), + favorite_color=user.get('user').get('favorite_color')))) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + + await producer.produce(topic, value=awarded_user, partition=0) + + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + + consumer.assign([TopicPartition(topic, 0)]) + + msg = await consumer.poll() + awarded_user2 = msg.value() + + assert awarded_user2 == awarded_user + + +@pytest.mark.parametrize("avsc, data, record_type", + [('basic_schema.avsc', {'name': 'abc'}, "record"), + ('primitive_string.avsc', u'Jämtland', "string"), + ('primitive_bool.avsc', True, "bool"), + ('primitive_float.avsc', 32768.2342, "float"), + ('primitive_double.avsc', 68.032768, "float")]) +async def test_avro_record_serialization(kafka_cluster, load_file, avsc, data, record_type): + """ + Tests basic Avro serializer functionality + + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + load_file (callable(str)): Avro file reader + avsc (str) avsc: Avro schema file + data (object): data to be serialized + + """ + topic = 
kafka_cluster.create_topic_and_wait_propogation("serialization-avro") + sr = kafka_cluster.async_schema_registry() + + schema_str = load_file(avsc) + value_serializer = await AsyncAvroSerializer(sr, schema_str) + + value_deserializer = await AsyncAvroDeserializer(sr) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + + await producer.produce(topic, value=data, partition=0) + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + msg = await consumer.poll() + actual = msg.value() + + if record_type == 'record': + assert [v == actual[k] for k, v in data.items()] + elif record_type == 'float': + assert data == pytest.approx(actual) + else: + assert actual == data + + +@pytest.mark.parametrize("avsc, data,record_type", + [('basic_schema.avsc', dict(name='abc'), 'record'), + ('primitive_string.avsc', u'Jämtland', 'string'), + ('primitive_bool.avsc', True, 'bool'), + ('primitive_float.avsc', 768.2340, 'float'), + ('primitive_double.avsc', 6.868, 'float')]) +async def test_delivery_report_serialization(kafka_cluster, load_file, avsc, data, record_type): + """ + Tests basic Avro serializer functionality + + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + load_file (callable(str)): Avro file reader + avsc (str) avsc: Avro schema file + data (object): data to be serialized + + """ + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-avro-dr") + sr = kafka_cluster.async_schema_registry() + schema_str = load_file(avsc) + + value_serializer = await AsyncAvroSerializer(sr, schema_str) + + value_deserializer = await AsyncAvroDeserializer(sr) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + + async def assert_cb(err, msg): + actual = value_deserializer(msg.value(), + SerializationContext(topic, MessageField.VALUE, msg.headers())) + + if record_type == "record": + assert [v == actual[k] for k, v in data.items()] + elif record_type == 'float': + assert data == pytest.approx(actual) + else: + assert actual == data + + await producer.produce(topic, value=data, partition=0, on_delivery=assert_cb) + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + msg = await consumer.poll() + actual = msg.value() + + # schema may include default which need not exist in the original + if record_type == 'record': + assert [v == actual[k] for k, v in data.items()] + elif record_type == 'float': + assert data == pytest.approx(actual) + else: + assert actual == data + + +async def test_avro_record_serialization_custom(kafka_cluster): + """ + Tests basic Avro serializer to_dict and from_dict object hook functionality. 
+ + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + + """ + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-avro") + sr = kafka_cluster.async_schema_registry() + + user = User('Bowie', 47, 'purple') + value_serializer = await AsyncAvroSerializer(sr, User.schema_str, + lambda user, ctx: + dict(name=user.name, + favorite_number=user.favorite_number, + favorite_color=user.favorite_color)) + + value_deserializer = await AsyncAvroDeserializer(sr, User.schema_str, + lambda user_dict, ctx: + User(**user_dict)) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + + await producer.produce(topic, value=user, partition=0) + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + msg = await consumer.poll() + user2 = msg.value() + + assert user2 == user + + +async def test_avro_reference(kafka_cluster): + """ + Tests Avro schema reference with both serializer and deserializer schemas provided. + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + """ + awarded_user, schema = await _register_avro_schemas_and_build_awarded_user_schema(kafka_cluster) + + await _references_test_common(kafka_cluster, awarded_user, schema, schema) + + +async def test_avro_reference_deserializer_none(kafka_cluster): + """ + Tests Avro schema reference with serializer schema provided and deserializer schema set to None. + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + """ + awarded_user, schema = await _register_avro_schemas_and_build_awarded_user_schema(kafka_cluster) + + await _references_test_common(kafka_cluster, awarded_user, schema, None) diff --git a/tests/integration/schema_registry/_async/test_json_serializers.py b/tests/integration/schema_registry/_async/test_json_serializers.py new file mode 100644 index 000000000..8f45e0aaf --- /dev/null +++ b/tests/integration/schema_registry/_async/test_json_serializers.py @@ -0,0 +1,490 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
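For readers skimming past the fixtures, the pattern these Avro tests rely on can be shown in a compact standalone form. This is only a sketch: the registry URL, topic name, and schema below are placeholders, and a reachable Schema Registry is assumed.

    import asyncio

    from confluent_kafka.schema_registry import AsyncSchemaRegistryClient
    from confluent_kafka.schema_registry.avro import AsyncAvroSerializer, AsyncAvroDeserializer
    from confluent_kafka.serialization import MessageField, SerializationContext

    SCHEMA_STR = """
    {"type": "record", "name": "Example",
     "fields": [{"name": "name", "type": "string"}]}
    """

    async def round_trip():
        sr = AsyncSchemaRegistryClient({'url': 'http://localhost:8081'})  # placeholder URL
        serializer = await AsyncAvroSerializer(sr, SCHEMA_STR)      # constructors are awaited
        deserializer = await AsyncAvroDeserializer(sr, SCHEMA_STR)
        ctx = SerializationContext('example-topic', MessageField.VALUE)
        payload = await serializer({'name': 'Bowie'}, ctx)          # serde calls are awaited too
        return await deserializer(payload, ctx)

    print(asyncio.run(round_trip()))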
+# +import pytest +from confluent_kafka import TopicPartition + +from confluent_kafka.error import ConsumeError, ValueSerializationError +from confluent_kafka.schema_registry import SchemaReference, Schema, AsyncSchemaRegistryClient +from confluent_kafka.schema_registry.json_schema import (AsyncJSONSerializer, + AsyncJSONDeserializer) + + +class _TestProduct(object): + def __init__(self, product_id, name, price, tags, dimensions, location): + self.product_id = product_id + self.name = name + self.price = price + self.tags = tags + self.dimensions = dimensions + self.location = location + + def __eq__(self, other): + return all([ + self.product_id == other.product_id, + self.name == other.name, + self.price == other.price, + self.tags == other.tags, + self.dimensions == other.dimensions, + self.location == other.location + ]) + + +class _TestCustomer(object): + def __init__(self, name, id): + self.name = name + self.id = id + + def __eq__(self, other): + return all([ + self.name == other.name, + self.id == other.id + ]) + + +class _TestOrderDetails(object): + def __init__(self, id, customer): + self.id = id + self.customer = customer + + def __eq__(self, other): + return all([ + self.id == other.id, + self.customer == other.customer + ]) + + +class _TestOrder(object): + def __init__(self, order_details, product): + self.order_details = order_details + self.product = product + + def __eq__(self, other): + return all([ + self.order_details == other.order_details, + self.product == other.product + ]) + + +class _TestReferencedProduct(object): + def __init__(self, name, product): + self.name = name + self.product = product + + def __eq__(self, other): + return all([ + self.name == other.name, + self.product == other.product + ]) + + +def _testProduct_to_dict(product_obj, ctx): + """ + Returns testProduct instance in dict format. + + Args: + product_obj (_TestProduct): testProduct instance. + + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + Returns: + dict: product_obj as a dictionary. + + """ + return {"productId": product_obj.product_id, + "productName": product_obj.name, + "price": product_obj.price, + "tags": product_obj.tags, + "dimensions": product_obj.dimensions, + "warehouseLocation": product_obj.location} + + +def _testCustomer_to_dict(customer_obj, ctx): + """ + Returns testCustomer instance in dict format. + + Args: + customer_obj (_TestCustomer): testCustomer instance. + + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + Returns: + dict: customer_obj as a dictionary. + + """ + return {"name": customer_obj.name, + "id": customer_obj.id} + + +def _testOrderDetails_to_dict(orderdetails_obj, ctx): + """ + Returns testOrderDetails instance in dict format. + + Args: + orderdetails_obj (_TestOrderDetails): testOrderDetails instance. + + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + Returns: + dict: orderdetails_obj as a dictionary. + + """ + return {"id": orderdetails_obj.id, + "customer": _testCustomer_to_dict(orderdetails_obj.customer, ctx)} + + +def _testOrder_to_dict(order_obj, ctx): + """ + Returns testOrder instance in dict format. + + Args: + order_obj (_TestOrder): testOrder instance. + + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + Returns: + dict: order_obj as a dictionary. 
+ + """ + return {"order_details": _testOrderDetails_to_dict(order_obj.order_details, ctx), + "product": _testProduct_to_dict(order_obj.product, ctx)} + + +def _testProduct_from_dict(product_dict, ctx): + """ + Returns testProduct instance from its dict format. + + Args: + product_dict (dict): testProduct in dict format. + + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + Returns: + _TestProduct: product_obj instance. + + """ + return _TestProduct(product_dict['productId'], + product_dict['productName'], + product_dict['price'], + product_dict['tags'], + product_dict['dimensions'], + product_dict['warehouseLocation']) + + +def _testCustomer_from_dict(customer_dict, ctx): + """ + Returns testCustomer instance from its dict format. + + Args: + customer_dict (dict): testCustomer in dict format. + + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + Returns: + _TestCustomer: customer_obj instance. + + """ + return _TestCustomer(customer_dict['name'], + customer_dict['id']) + + +def _testOrderDetails_from_dict(orderdetails_dict, ctx): + """ + Returns testOrderDetails instance from its dict format. + + Args: + orderdetails_dict (dict): testOrderDetails in dict format. + + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + Returns: + _TestOrderDetails: orderdetails_obj instance. + + """ + return _TestOrderDetails(orderdetails_dict['id'], + _testCustomer_from_dict(orderdetails_dict['customer'], ctx)) + + +def _testOrder_from_dict(order_dict, ctx): + """ + Returns testOrder instance from its dict format. + + Args: + order_dict (dict): testOrder in dict format. + + ctx (SerializationContext): Metadata pertaining to the serialization + operation. + + Returns: + _TestOrder: order_obj instance. + + """ + return _TestOrder(_testOrderDetails_from_dict(order_dict['order_details'], ctx), + _testProduct_from_dict(order_dict['product'], ctx)) + + +async def test_json_record_serialization(kafka_cluster, load_file): + """ + Tests basic JsonSerializer and JsonDeserializer basic functionality. + + product.json from: + https://json-schema.org/learn/getting-started-step-by-step.html + + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + + load_file (callable(str)): JSON Schema file reader + + """ + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-json") + sr = kafka_cluster.async_schema_registry() + + schema_str = load_file("product.json") + value_serializer = await AsyncJSONSerializer(schema_str, sr) + value_deserializer = await AsyncJSONDeserializer(schema_str) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + + record = {"productId": 1, + "productName": "An ice sculpture", + "price": 12.50, + "tags": ["cold", "ice"], + "dimensions": { + "length": 7.0, + "width": 12.0, + "height": 9.5 + }, + "warehouseLocation": { + "latitude": -78.75, + "longitude": 20.4 + }} + + await producer.produce(topic, value=record, partition=0) + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + msg = await consumer.poll() + actual = msg.value() + + assert all([actual[k] == v for k, v in record.items()]) + + +async def test_json_record_serialization_incompatible(kafka_cluster, load_file): + """ + Tests Serializer validation functionality. 
+ + product.json from: + https://json-schema.org/learn/getting-started-step-by-step.html + + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + + load_file (callable(str)): JSON Schema file reader + + """ + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-json") + sr = kafka_cluster.async_schema_registry() + + schema_str = load_file("product.json") + value_serializer = await AsyncJSONSerializer(schema_str, sr) + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + + record = {"contractorId": 1, + "contractorName": "David Davidson", + "contractRate": 1250, + "trades": ["mason"]} + + with pytest.raises(ValueSerializationError, + match=r"(.*) is a required property"): + await producer.produce(topic, value=record, partition=0) + + +async def test_json_record_serialization_custom(kafka_cluster, load_file): + """ + Ensures to_dict and from_dict hooks are properly applied by the serializer. + + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + + load_file (callable(str)): JSON Schema file reader + + """ + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-json") + sr = kafka_cluster.async_schema_registry() + + schema_str = load_file("product.json") + value_serializer = await AsyncJSONSerializer(schema_str, sr, + to_dict=_testProduct_to_dict) + value_deserializer = await AsyncJSONDeserializer(schema_str, + from_dict=_testProduct_from_dict) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + + record = _TestProduct(product_id=1, + name="The ice sculpture", + price=12.50, + tags=["cold", "ice"], + dimensions={"length": 7.0, + "width": 12.0, + "height": 9.5}, + location={"latitude": -78.75, + "longitude": 20.4}) + + await producer.produce(topic, value=record, partition=0) + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + msg = await consumer.poll() + actual = msg.value() + + assert all([getattr(actual, attribute) == getattr(record, attribute) + for attribute in vars(record)]) + + +async def test_json_record_deserialization_mismatch(kafka_cluster, load_file): + """ + Ensures to_dict and from_dict hooks are properly applied by the serializer. 
+ + Args: + kafka_cluster (KafkaClusterFixture): cluster fixture + + load_file (callable(str)): JSON Schema file reader + + """ + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-json") + sr = kafka_cluster.async_schema_registry() + + schema_str = load_file("contractor.json") + schema_str2 = load_file("product.json") + + value_serializer = await AsyncJSONSerializer(schema_str, sr) + value_deserializer = await AsyncJSONDeserializer(schema_str2) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + + record = {"contractorId": 2, + "contractorName": "Magnus Edenhill", + "contractRate": 30, + "trades": ["pickling"]} + + await producer.produce(topic, value=record, partition=0) + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + with pytest.raises( + ConsumeError, + match="'productId' is a required property"): + await consumer.poll() + + +async def _register_referenced_schemas(sr: AsyncSchemaRegistryClient, load_file): + await sr.register_schema("product", Schema(load_file("product.json"), 'JSON')) + await sr.register_schema("customer", Schema(load_file("customer.json"), 'JSON')) + await sr.register_schema("order_details", Schema(load_file("order_details.json"), 'JSON', [ + SchemaReference("http://example.com/customer.schema.json", "customer", 1)])) + + order_schema = Schema(load_file("order.json"), 'JSON', + [SchemaReference("http://example.com/order_details.schema.json", "order_details", 1), + SchemaReference("http://example.com/product.schema.json", "product", 1)]) + return order_schema + + +async def test_json_reference(kafka_cluster, load_file): + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-json") + sr = kafka_cluster.async_schema_registry() + + product = {"productId": 1, + "productName": "An ice sculpture", + "price": 12.50, + "tags": ["cold", "ice"], + "dimensions": { + "length": 7.0, + "width": 12.0, + "height": 9.5 + }, + "warehouseLocation": { + "latitude": -78.75, + "longitude": 20.4 + }} + customer = {"name": "John Doe", "id": 1} + order_details = {"id": 1, "customer": customer} + order = {"order_details": order_details, "product": product} + + schema = await _register_referenced_schemas(sr, load_file) + + value_serializer = await AsyncJSONSerializer(schema, sr) + value_deserializer = await AsyncJSONDeserializer(schema, schema_registry_client=sr) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + await producer.produce(topic, value=order, partition=0) + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + msg = await consumer.poll() + actual = msg.value() + + assert all([actual[k] == v for k, v in order.items()]) + + +async def test_json_reference_custom(kafka_cluster, load_file): + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-json") + sr = kafka_cluster.async_schema_registry() + + product = _TestProduct(product_id=1, + name="The ice sculpture", + price=12.50, + tags=["cold", "ice"], + dimensions={"length": 7.0, + "width": 12.0, + "height": 9.5}, + location={"latitude": -78.75, + "longitude": 20.4}) + customer = _TestCustomer(name="John Doe", id=1) + order_details = _TestOrderDetails(id=1, customer=customer) + order = _TestOrder(order_details=order_details, product=product) + + schema = await _register_referenced_schemas(sr, load_file) + + value_serializer = 
await AsyncJSONSerializer(schema, sr, to_dict=_testOrder_to_dict) + value_deserializer = await AsyncJSONDeserializer(schema, schema_registry_client=sr, from_dict=_testOrder_from_dict) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + await producer.produce(topic, value=order, partition=0) + producer.flush() + + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + msg = await consumer.poll() + actual = msg.value() + + assert actual == order diff --git a/tests/integration/schema_registry/_async/test_proto_serializers.py b/tests/integration/schema_registry/_async/test_proto_serializers.py new file mode 100644 index 000000000..5f0b9f413 --- /dev/null +++ b/tests/integration/schema_registry/_async/test_proto_serializers.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python +# +# Copyright 2016 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest + +from confluent_kafka import TopicPartition, KafkaException, KafkaError +from confluent_kafka.error import ConsumeError +from confluent_kafka.schema_registry.protobuf import AsyncProtobufSerializer, AsyncProtobufDeserializer +from tests.integration.schema_registry.data.proto import metadata_proto_pb2, NestedTestProto_pb2, TestProto_pb2, \ + PublicTestProto_pb2 +from tests.integration.schema_registry.data.proto.DependencyTestProto_pb2 import DependencyMessage +from tests.integration.schema_registry.data.proto.exampleProtoCriteo_pb2 import ClickCas + + +@pytest.mark.parametrize("pb2, data", [ + (TestProto_pb2.TestMessage, {'test_string': "abc", + 'test_bool': True, + 'test_bytes': b'look at these bytes', + 'test_double': 1.0, + 'test_float': 12.0}), + (PublicTestProto_pb2.TestMessage, {'test_string': "abc", + 'test_bool': True, + 'test_bytes': b'look at these bytes', + 'test_double': 1.0, + 'test_float': 12.0}), + (NestedTestProto_pb2.NestedMessage, {'user_id': + NestedTestProto_pb2.UserId( + kafka_user_id='oneof_str'), + 'is_active': True, + 'experiments_active': ['x', 'y', '1'], + 'status': NestedTestProto_pb2.INACTIVE, + 'complex_type': + NestedTestProto_pb2.ComplexType( + one_id='oneof_str', + is_active=False)}) +]) +async def test_protobuf_message_serialization(kafka_cluster, pb2, data): + """ + Validates that we get the same message back that we put in. 
+ + """ + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-proto") + sr = kafka_cluster.async_schema_registry() + + value_serializer = await AsyncProtobufSerializer(pb2, sr, {'use.deprecated.format': False}) + value_deserializer = await AsyncProtobufDeserializer(pb2, {'use.deprecated.format': False}) + + producer = kafka_cluster.async_producer(value_serializer=value_serializer) + consumer = kafka_cluster.async_consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + expect = pb2(**data) + await producer.produce(topic, value=expect, partition=0) + producer.flush() + + msg = await consumer.poll() + actual = msg.value() + + assert [getattr(expect, k) == getattr(actual, k) for k in data.keys()] + + +@pytest.mark.parametrize("pb2, expected_refs", [ + (TestProto_pb2.TestMessage, ['google/protobuf/descriptor.proto']), + (NestedTestProto_pb2.NestedMessage, ['google/protobuf/timestamp.proto']), + (DependencyMessage, ['NestedTestProto.proto', 'PublicTestProto.proto']), + (ClickCas, ['metadata_proto.proto', 'common_proto.proto']) +]) +async def test_protobuf_reference_registration(kafka_cluster, pb2, expected_refs): + """ + Registers multiple messages with dependencies then queries the Schema + Registry to ensure the references match up. + + """ + sr = kafka_cluster.async_schema_registry() + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-proto-refs") + serializer = await AsyncProtobufSerializer(pb2, sr, {'use.deprecated.format': False}) + producer = kafka_cluster.async_producer(key_serializer=serializer) + + await producer.produce(topic, key=pb2(), partition=0) + producer.flush() + + registered_refs = (await sr.get_schema(serializer._schema_id)).references + + assert expected_refs.sort() == [ref.name for ref in registered_refs].sort() + + +async def test_protobuf_serializer_type_mismatch(kafka_cluster): + """ + Ensures an Exception is raised when deserializing an unexpected type. + + """ + pb2_1 = TestProto_pb2.TestMessage + pb2_2 = NestedTestProto_pb2.NestedMessage + + sr = kafka_cluster.async_schema_registry() + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-proto-refs") + serializer = await AsyncProtobufSerializer(pb2_1, sr, {'use.deprecated.format': False}) + + producer = kafka_cluster.async_producer(key_serializer=serializer) + + with pytest.raises(KafkaException, + match=r"message must be of type not \"): + await producer.produce(topic, key=pb2_2()) + + +async def test_protobuf_deserializer_type_mismatch(kafka_cluster): + """ + Ensures an Exception is raised when deserializing an unexpected type. 
+ + """ + pb2_1 = PublicTestProto_pb2.TestMessage + pb2_2 = metadata_proto_pb2.HDFSOptions + + sr = kafka_cluster.async_schema_registry() + topic = kafka_cluster.create_topic_and_wait_propogation("serialization-proto-refs") + serializer = await AsyncProtobufSerializer(pb2_1, sr, {'use.deprecated.format': False}) + deserializer = await AsyncProtobufDeserializer(pb2_2, {'use.deprecated.format': False}) + + producer = kafka_cluster.async_producer(key_serializer=serializer) + consumer = kafka_cluster.async_consumer(key_deserializer=deserializer) + consumer.assign([TopicPartition(topic, 0)]) + + def dr(err, msg): + print("dr msg {} {}".format(msg.key(), msg.value())) + + await producer.produce(topic, key=pb2_1(test_string='abc', + test_bool=True, + test_bytes=b'def'), + partition=0) + producer.flush() + + with pytest.raises(ConsumeError) as e: + await consumer.poll() + assert e.value.code == KafkaError._KEY_DESERIALIZATION From 32507b5fd338d9844200afb0be4b12a40617eec4 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Thu, 17 Apr 2025 14:32:07 -0700 Subject: [PATCH 14/32] refactor in prep for adding async --- src/confluent_kafka/schema_registry/avro.py | 1 + .../schema_registry/common/__init__.py | 35 ++++++ .../schema_registry/json_schema.py | 1 + .../schema_registry/protobuf.py | 1 + .../schema_registry/schema_registry_client.py | 1 + src/confluent_kafka/schema_registry/serde.py | 1 + tests/common/_async/consumer.py | 100 +++++++++++++++ tests/common/_async/producer.py | 116 ++++++++++++++++++ 8 files changed, 256 insertions(+) create mode 100644 tests/common/_async/consumer.py create mode 100644 tests/common/_async/producer.py diff --git a/src/confluent_kafka/schema_registry/avro.py b/src/confluent_kafka/schema_registry/avro.py index 0a9ba2db2..94dcd90ec 100644 --- a/src/confluent_kafka/schema_registry/avro.py +++ b/src/confluent_kafka/schema_registry/avro.py @@ -17,3 +17,4 @@ from .common.avro import * # noqa from ._sync.avro import * # noqa +from ._async.avro import * # noqa diff --git a/src/confluent_kafka/schema_registry/common/__init__.py b/src/confluent_kafka/schema_registry/common/__init__.py index e69de29bb..abbd1c448 100644 --- a/src/confluent_kafka/schema_registry/common/__init__.py +++ b/src/confluent_kafka/schema_registry/common/__init__.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +def asyncinit(cls): + """ + Decorator to make a class async-initializable. 
+ """ + __new__ = cls.__new__ + + async def init(obj, *arg, **kwarg): + await obj.__init__(*arg, **kwarg) + return obj + + def new(klass, *arg, **kwarg): + obj = __new__(klass) + coro = init(obj, *arg, **kwarg) + return coro + + cls.__new__ = new + return cls diff --git a/src/confluent_kafka/schema_registry/json_schema.py b/src/confluent_kafka/schema_registry/json_schema.py index e60c8eafd..ca4d74ee9 100644 --- a/src/confluent_kafka/schema_registry/json_schema.py +++ b/src/confluent_kafka/schema_registry/json_schema.py @@ -16,4 +16,5 @@ # limitations under the License. from .common.json_schema import * # noqa +from ._async.json_schema import * # noqa from ._sync.json_schema import * # noqa diff --git a/src/confluent_kafka/schema_registry/protobuf.py b/src/confluent_kafka/schema_registry/protobuf.py index c781e47be..5127398b6 100644 --- a/src/confluent_kafka/schema_registry/protobuf.py +++ b/src/confluent_kafka/schema_registry/protobuf.py @@ -16,4 +16,5 @@ # limitations under the License. from .common.protobuf import * # noqa +from ._async.protobuf import * # noqa from ._sync.protobuf import * # noqa diff --git a/src/confluent_kafka/schema_registry/schema_registry_client.py b/src/confluent_kafka/schema_registry/schema_registry_client.py index b7a008097..be6c70119 100644 --- a/src/confluent_kafka/schema_registry/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/schema_registry_client.py @@ -17,6 +17,7 @@ from .common.schema_registry_client import * # noqa +from ._async.schema_registry_client import * # noqa from ._sync.schema_registry_client import * # noqa from .error import SchemaRegistryError # noqa diff --git a/src/confluent_kafka/schema_registry/serde.py b/src/confluent_kafka/schema_registry/serde.py index 25fe4a48f..0574fc3b1 100644 --- a/src/confluent_kafka/schema_registry/serde.py +++ b/src/confluent_kafka/schema_registry/serde.py @@ -17,4 +17,5 @@ # from .common.serde import * # noqa +from ._async.serde import * # noqa from ._sync.serde import * # noqa diff --git a/tests/common/_async/consumer.py b/tests/common/_async/consumer.py new file mode 100644 index 000000000..bb74c7b51 --- /dev/null +++ b/tests/common/_async/consumer.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2025 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
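A minimal sketch of how the asyncinit decorator above is intended to be used; the Example class and the awaited sleep are illustrative only and not part of this patch.

    import asyncio

    from confluent_kafka.schema_registry.common import asyncinit

    @asyncinit
    class Example:
        async def __init__(self, value):
            # __init__ may now await: the decorator swaps in a __new__ that
            # returns a coroutine which drives this method and then yields
            # the finished instance.
            await asyncio.sleep(0)
            self.value = value

    async def main():
        obj = await Example(42)  # construction itself is awaited
        print(obj.value)

    asyncio.run(main())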
+# + +import asyncio + +from confluent_kafka.cimpl import Consumer +from confluent_kafka.error import ConsumeError, KeyDeserializationError, ValueDeserializationError +from confluent_kafka.serialization import MessageField, SerializationContext + +ASYNC_CONSUMER_POLL_INTERVAL_SECONDS: int = 0.2 +ASYNC_CONSUMER_POLL_INFINITE_TIMEOUT_SECONDS: int = -1 + +class AsyncConsumer(Consumer): + def __init__( + self, + conf: dict, + loop: asyncio.AbstractEventLoop = None, + poll_interval_seconds: int = ASYNC_CONSUMER_POLL_INTERVAL_SECONDS + ): + super().__init__(conf) + + self._loop = loop or asyncio.get_event_loop() + self._poll_interval = poll_interval_seconds + + def __aiter__(self): + return self + + async def __anext__(self): + return await self.poll(None) + + async def poll(self, timeout: int = -1): + timeout = None if timeout == -1 else timeout + async with asyncio.timeout(timeout): + while True: + # Zero timeout here is what makes it non-blocking + msg = super().poll(0) + if msg is not None: + return msg + else: + await asyncio.sleep(self._poll_interval) + + +class TestAsyncDeserializingConsumer(AsyncConsumer): + def __init__(self, conf): + conf_copy = conf.copy() + self._key_deserializer = conf_copy.pop('key.deserializer', None) + self._value_deserializer = conf_copy.pop('value.deserializer', None) + super().__init__(conf_copy) + + async def poll(self, timeout=-1): + msg = await super().poll(timeout) + + if msg is None: + return None + + if msg.error() is not None: + raise ConsumeError(msg.error(), kafka_message=msg) + + ctx = SerializationContext(msg.topic(), MessageField.VALUE, msg.headers()) + value = msg.value() + if self._value_deserializer is not None: + try: + value = await self._value_deserializer(value, ctx) + except Exception as se: + raise ValueDeserializationError(exception=se, kafka_message=msg) + + key = msg.key() + ctx.field = MessageField.KEY + if self._key_deserializer is not None: + try: + key = await self._key_deserializer(key, ctx) + except Exception as se: + raise KeyDeserializationError(exception=se, kafka_message=msg) + + msg.set_key(key) + msg.set_value(value) + return msg + + def consume(self, num_messages=1, timeout=-1): + """ + :py:func:`Consumer.consume` not implemented, use + :py:func:`DeserializingConsumer.poll` instead + """ + + raise NotImplementedError diff --git a/tests/common/_async/producer.py b/tests/common/_async/producer.py new file mode 100644 index 000000000..e9f811700 --- /dev/null +++ b/tests/common/_async/producer.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2025 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
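A rough usage sketch for the AsyncConsumer helper above. It assumes the repository's tests package is importable, a broker at the placeholder address, and Python 3.11+ for asyncio.timeout; the loop runs until cancelled.

    import asyncio

    from tests.common._async.consumer import AsyncConsumer

    async def consume():
        consumer = AsyncConsumer({
            'bootstrap.servers': 'localhost:9092',  # placeholder broker
            'group.id': 'example-group',            # placeholder group
            'auto.offset.reset': 'earliest',
        })
        consumer.subscribe(['example-topic'])       # placeholder topic
        try:
            # The helper calls the blocking consumer with poll(0) and sleeps
            # between empty polls, so other coroutines keep running.
            async for msg in consumer:
                print(msg.topic(), msg.partition(), msg.value())
        finally:
            consumer.close()

    asyncio.run(consume())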
+# + +from confluent_kafka.cimpl import Producer +import inspect +import asyncio + +from confluent_kafka.error import KeySerializationError, ValueSerializationError +from confluent_kafka.serialization import MessageField, SerializationContext + +ASYNC_PRODUCER_POLL_INTERVAL: int = 0.2 + +class AsyncProducer(Producer): + def __init__( + self, + conf: dict, + loop: asyncio.AbstractEventLoop = None, + poll_interval: int = ASYNC_PRODUCER_POLL_INTERVAL + ): + super().__init__(conf) + + self._loop = loop or asyncio.get_event_loop() + self._poll_interval = poll_interval + + self._poll_task = None + self._waiters: int = 0 + + async def produce( + self, topic, value=None, key=None, partition=-1, + on_delivery=None, timestamp=0, headers=None + ): + fut = self._loop.create_future() + self._waiters += 1 + try: + if self._poll_task is None or self._poll_task.done(): + self._poll_task = asyncio.create_task(self._poll_dr(self._poll_interval)) + + def wrapped_on_delivery(err, msg): + if on_delivery is not None: + if inspect.iscoroutinefunction(on_delivery): + asyncio.run_coroutine_threadsafe( + on_delivery(err, msg), + self._loop + ) + else: + self._loop.call_soon_threadsafe(on_delivery, err, msg) + + if err: + self._loop.call_soon_threadsafe(fut.set_exception, err) + else: + self._loop.call_soon_threadsafe(fut.set_result, msg) + + super().produce( + topic, + value, + key, + headers=headers, + partition=partition, + timestamp=timestamp, + on_delivery=wrapped_on_delivery + ) + return await fut + finally: + self._waiters -= 1 + + async def _poll_dr(self, interval: int): + """Poll delivery reports at interval seconds""" + while self._waiters: + super().poll(0) + await asyncio.sleep(interval) + + +class TestAsyncSerializingProducer(AsyncProducer): + def __init__(self, conf): + conf_copy = conf.copy() + + self._key_serializer = conf_copy.pop('key.serializer', None) + self._value_serializer = conf_copy.pop('value.serializer', None) + + super(TestAsyncSerializingProducer, self).__init__(conf_copy) + + async def produce(self, topic, key=None, value=None, partition=-1, + on_delivery=None, timestamp=0, headers=None): + ctx = SerializationContext(topic, MessageField.KEY, headers) + if self._key_serializer is not None: + try: + key = await self._key_serializer(key, ctx) + except Exception as se: + raise KeySerializationError(se) + ctx.field = MessageField.VALUE + if self._value_serializer is not None: + try: + value = await self._value_serializer(value, ctx) + except Exception as se: + raise ValueSerializationError(se) + + return await super().produce(topic, value, key, + headers=headers, + partition=partition, + timestamp=timestamp, + on_delivery=on_delivery) From b1185e055df0eb36eec41b753037de2300a6f98d Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Thu, 17 Apr 2025 14:36:01 -0700 Subject: [PATCH 15/32] Implement async variants of SR clients and serdes --- src/confluent_kafka/schema_registry/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/confluent_kafka/schema_registry/__init__.py b/src/confluent_kafka/schema_registry/__init__.py index 2d81f44ac..25f21db7f 100644 --- a/src/confluent_kafka/schema_registry/__init__.py +++ b/src/confluent_kafka/schema_registry/__init__.py @@ -57,6 +57,7 @@ "RuleSet", "Schema", "SchemaRegistryClient", + "AsyncSchemaRegistryClient", "SchemaRegistryError", "SchemaReference", "ServerConfig", From 63d385dba7b7e9515f7bef137abe93aab73a7a84 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Mon, 21 Apr 2025 11:18:26 -0700 Subject: [PATCH 16/32] Refactor --- 
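As a companion to the consumer sketch, a rough usage sketch for the AsyncProducer test helper added earlier in this series; the broker address and topic are placeholders, and the repository's tests package is assumed importable.

    import asyncio

    from tests.common._async.producer import AsyncProducer

    async def send_one():
        producer = AsyncProducer({'bootstrap.servers': 'localhost:9092'})  # placeholder broker
        # produce() wires librdkafka's delivery callback to an asyncio future,
        # so awaiting it yields the delivered message once it is acknowledged.
        msg = await producer.produce('example-topic', value=b'hello')
        print(msg.topic(), msg.partition(), msg.offset())

    asyncio.run(send_one())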
.../schema_registry/_async/avro.py | 8 +++++- .../schema_registry/_async/json_schema.py | 5 ++++ .../schema_registry/_async/protobuf.py | 5 ++++ .../_async/schema_registry_client.py | 25 +++++++++---------- .../schema_registry/_async/serde.py | 6 +++++ .../schema_registry/_sync/avro.py | 1 + 6 files changed, 36 insertions(+), 14 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_async/avro.py b/src/confluent_kafka/schema_registry/_async/avro.py index 713182408..69cb87f1a 100644 --- a/src/confluent_kafka/schema_registry/_async/avro.py +++ b/src/confluent_kafka/schema_registry/_async/avro.py @@ -30,10 +30,16 @@ AsyncSchemaRegistryClient) from confluent_kafka.serialization import (SerializationError, SerializationContext) -from confluent_kafka.schema_registry.common import _ContextStringIO, asyncinit +from confluent_kafka.schema_registry.common import asyncinit +from confluent_kafka.schema_registry.common import _ContextStringIO from confluent_kafka.schema_registry.rule_registry import RuleRegistry from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, ParsedSchemaCache +__all__ = [ + '_resolve_named_schema', + 'AsyncAvroSerializer', + 'AsyncAvroDeserializer', +] async def _resolve_named_schema( schema: Schema, schema_registry_client: AsyncSchemaRegistryClient diff --git a/src/confluent_kafka/schema_registry/_async/json_schema.py b/src/confluent_kafka/schema_registry/_async/json_schema.py index c12e5869c..a1d6d6b7a 100644 --- a/src/confluent_kafka/schema_registry/_async/json_schema.py +++ b/src/confluent_kafka/schema_registry/_async/json_schema.py @@ -40,6 +40,11 @@ from confluent_kafka.serialization import (SerializationError, SerializationContext) +__all__ = [ + '_resolve_named_schema', + 'AsyncJSONSerializer', + 'AsyncJSONDeserializer' +] async def _resolve_named_schema( schema: Schema, schema_registry_client: AsyncSchemaRegistryClient, diff --git a/src/confluent_kafka/schema_registry/_async/protobuf.py b/src/confluent_kafka/schema_registry/_async/protobuf.py index f6e850dea..9fd744034 100644 --- a/src/confluent_kafka/schema_registry/_async/protobuf.py +++ b/src/confluent_kafka/schema_registry/_async/protobuf.py @@ -40,6 +40,11 @@ from confluent_kafka.schema_registry.common import asyncinit from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, ParsedSchemaCache +__all__ = [ + '_resolve_named_schema', + 'AsyncProtobufSerializer', + 'AsyncProtobufDeserializer', +] async def _resolve_named_schema( schema: Schema, diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index ea49e4bd0..223848558 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -38,10 +38,20 @@ is_retriable, _BearerFieldProvider, full_jitter, - _SchemaCache, + _SchemaCache, Schema, + _StaticFieldProvider, ) +__all__ = [ + '_urlencode', + '_AsyncCustomOAuthClient', + '_AsyncOAuthClient', + '_AsyncBaseRestClient', + '_AsyncRestClient', + 'AsyncSchemaRegistryClient', +] + # TODO: consider adding `six` dependency or employing a compat file # Python 2.7 is officially EOL so compatibility issue will be come more the norm. # We need a better way to handle these issues. 
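One point worth calling out about the __all__ lists being added above: the public modules keep re-exporting both variants through star imports, so the sync and async classes remain importable from the same paths. A small illustration, using names that appear in this series:

    # _sync/ and _async/ hold the implementations; the existing public modules
    # re-export both, so caller imports stay stable.
    from confluent_kafka.schema_registry import SchemaRegistryClient, AsyncSchemaRegistryClient
    from confluent_kafka.schema_registry.avro import AvroSerializer, AsyncAvroSerializer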
@@ -64,17 +74,6 @@ def _urlencode(value: str) -> str: log = logging.getLogger(__name__) -class _AsyncStaticFieldProvider(_BearerFieldProvider): - def __init__(self, token: str, logical_cluster: str, identity_pool: str): - self.token = token - self.logical_cluster = logical_cluster - self.identity_pool = identity_pool - - async def get_bearer_fields(self) -> dict: - return {'bearer.auth.token': self.token, 'bearer.auth.logical.cluster': self.logical_cluster, - 'bearer.auth.identity.pool.id': self.identity_pool} - - class _AsyncCustomOAuthClient(_BearerFieldProvider): def __init__(self, custom_function: Callable[[Dict], Dict], custom_config: dict): self.custom_function = custom_function @@ -292,7 +291,7 @@ def __init__(self, conf: dict): if 'bearer.auth.token' not in conf_copy: raise ValueError("Missing bearer.auth.token") static_token = conf_copy.pop('bearer.auth.token') - self.bearer_field_provider = _AsyncStaticFieldProvider(static_token, logical_cluster, identity_pool) + self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) if not isinstance(static_token, string_type): raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) elif self.bearer_auth_credentials_source == 'CUSTOM': diff --git a/src/confluent_kafka/schema_registry/_async/serde.py b/src/confluent_kafka/schema_registry/_async/serde.py index 677d916f7..f72d108c6 100644 --- a/src/confluent_kafka/schema_registry/_async/serde.py +++ b/src/confluent_kafka/schema_registry/_async/serde.py @@ -26,6 +26,12 @@ from confluent_kafka.serialization import Serializer, Deserializer, \ SerializationContext, SerializationError +__all__ = [ + 'AsyncBaseSerde', + 'AsyncBaseSerializer', + 'AsyncBaseDeserializer', +] + log = logging.getLogger(__name__) class AsyncBaseSerde(object): diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index 651dee7ca..147dd2a34 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -361,6 +361,7 @@ def _get_parsed_schema(self, schema: Schema) -> AvroSchema: return parsed_schema + class AvroDeserializer(BaseDeserializer): """ Deserializer for Avro binary encoded data with Confluent Schema Registry From f9874ca64240dbb96301d9ad5c24a7cd26311da2 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Wed, 23 Apr 2025 11:39:17 -0700 Subject: [PATCH 17/32] make functions async --- .../schema_registry/_async/schema_registry_client.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index 223848558..29b83e2ce 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -321,16 +321,16 @@ def __init__(self, conf: dict): raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) - def get(self, url: str, query: Optional[dict] = None) -> Any: + async def get(self, url: str, query: Optional[dict] = None) -> Any: raise NotImplementedError() - def post(self, url: str, body: Optional[dict], **kwargs) -> Any: + async def post(self, url: str, body: Optional[dict], **kwargs) -> Any: raise NotImplementedError() - def delete(self, url: str) -> Any: + async def delete(self, url: str) -> Any: raise NotImplementedError() - def put(self, url: str, body: 
Optional[dict] = None) -> Any: + async def put(self, url: str, body: Optional[dict] = None) -> Any: raise NotImplementedError() @@ -560,9 +560,9 @@ class AsyncSchemaRegistryClient(object): `Confluent Schema Registry documentation `_ """ # noqa: E501 - def __init__(self, conf: dict): + def __init__(self, conf: dict, rest_client: _AsyncRestClient = None): self._conf = conf - self._rest_client = _AsyncRestClient(conf) + self._rest_client = rest_client or _AsyncRestClient(conf) self._cache = _SchemaCache() cache_capacity = self._rest_client.cache_capacity cache_ttl = self._rest_client.cache_latest_ttl_sec From 2365f8de7428913fa66a58ada98c9c47b57c95b4 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Tue, 27 May 2025 16:09:20 -0700 Subject: [PATCH 18/32] make schema id serializer changes to async --- .../schema_registry/_async/__init__.py | 17 - .../schema_registry/_async/avro.py | 156 +++++---- .../schema_registry/_async/json_schema.py | 172 +++++----- .../schema_registry/_async/protobuf.py | 308 +++++++----------- .../_async/schema_registry_client.py | 73 ++++- .../schema_registry/_async/serde.py | 5 +- .../_sync/test_proto_serializers.py | 2 +- 7 files changed, 354 insertions(+), 379 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_async/__init__.py b/src/confluent_kafka/schema_registry/_async/__init__.py index 2b4389a06..e69de29bb 100644 --- a/src/confluent_kafka/schema_registry/_async/__init__.py +++ b/src/confluent_kafka/schema_registry/_async/__init__.py @@ -1,17 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright 2020 Confluent Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
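Because the constructor shown above now accepts an injected rest client, unit tests can avoid HTTP entirely. A rough sketch follows; the _StubRestClient below is hypothetical and only supplies the attributes and awaitable methods the client is known to touch.

    from confluent_kafka.schema_registry import AsyncSchemaRegistryClient

    class _StubRestClient:
        # Read by AsyncSchemaRegistryClient.__init__ when sizing its caches.
        cache_capacity = 1000
        cache_latest_ttl_sec = None

        async def get(self, url, query=None):
            return {"id": 1}

        async def post(self, url, body, **kwargs):
            return {"id": 1}

        async def delete(self, url):
            return 1

        async def put(self, url, body=None):
            return {}

    client = AsyncSchemaRegistryClient({'url': 'http://placeholder'},
                                       rest_client=_StubRestClient())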
-# diff --git a/src/confluent_kafka/schema_registry/_async/avro.py b/src/confluent_kafka/schema_registry/_async/avro.py index 69cb87f1a..c5f0d1dc3 100644 --- a/src/confluent_kafka/schema_registry/_async/avro.py +++ b/src/confluent_kafka/schema_registry/_async/avro.py @@ -20,20 +20,22 @@ from typing import Dict, Union, Optional, Callable from fastavro import schemaless_reader, schemaless_writer - -from confluent_kafka.schema_registry.common.avro import AvroSchema, _schema_loads, get_inline_tags, parse_schema_with_repo, transform +from confluent_kafka.schema_registry.common import asyncinit +from confluent_kafka.schema_registry.common.avro import AvroSchema, _schema_loads, \ + get_inline_tags, parse_schema_with_repo, transform, _ContextStringIO, AVRO_TYPE from confluent_kafka.schema_registry import (_MAGIC_BYTE, - Schema, - topic_subject_name_strategy, - RuleMode, - AsyncSchemaRegistryClient) + Schema, + topic_subject_name_strategy, + RuleMode, + AsyncSchemaRegistryClient, + prefix_schema_id_serializer, + dual_schema_id_deserializer) from confluent_kafka.serialization import (SerializationError, SerializationContext) -from confluent_kafka.schema_registry.common import asyncinit -from confluent_kafka.schema_registry.common import _ContextStringIO from confluent_kafka.schema_registry.rule_registry import RuleRegistry -from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, ParsedSchemaCache +from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, ParsedSchemaCache, SchemaId + __all__ = [ '_resolve_named_schema', @@ -41,6 +43,7 @@ 'AsyncAvroDeserializer', ] + async def _resolve_named_schema( schema: Schema, schema_registry_client: AsyncSchemaRegistryClient ) -> Dict[str, AvroSchema]: @@ -113,6 +116,12 @@ class AsyncAvroSerializer(AsyncBaseSerializer): | | | | | | | Defaults to topic_subject_name_strategy. | +-----------------------------+----------+--------------------------------------------------+ + | | | Callable(bytes, SerializationContext, schema_id) | + | | | -> bytes | + | | | | + | ``schema.id.serializer`` | callable | Defines how the schema id/guid is serialized. | + | | | Defaults to prefix_schema_id_serializer. | + +-----------------------------+----------+--------------------------------------------------+ Schemas are registered against subject names in Confluent Schema Registry that define a scope in which the schemas can be evolved. 
By default, the subject name @@ -172,7 +181,8 @@ class AsyncAvroSerializer(AsyncBaseSerializer): 'use.schema.id': None, 'use.latest.version': False, 'use.latest.with.metadata': None, - 'subject.name.strategy': topic_subject_name_strategy} + 'subject.name.strategy': topic_subject_name_strategy, + 'schema.id.serializer': prefix_schema_id_serializer} async def __init__( self, @@ -234,6 +244,10 @@ async def __init__( self._subject_name_func = conf_copy.pop('subject.name.strategy') if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") + + self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') + if not callable(self._schema_id_deserializer): + raise ValueError("schema.id.deserializer must be callable") if len(conf_copy) > 0: raise ValueError("Unrecognized properties: {}" @@ -297,19 +311,20 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N subject = self._subject_name_func(ctx, self._schema_name) latest_schema = await self._get_reader_schema(subject) if latest_schema is not None: - self._schema_id = latest_schema.schema_id + self._schema_id = SchemaId(AVRO_TYPE, latest_schema.schema_id, latest_schema.guid) elif subject not in self._known_subjects: # Check to ensure this schema has been registered under subject_name. if self._auto_register: # The schema name will always be the same. We can't however register # a schema without a subject so we set the schema_id here to handle # the initial registration. - self._schema_id = await self._registry.register_schema( + registered_schema = await self._registry.register_schema_full_response( subject, self._schema, self._normalize_schemas) + self._schema_id = SchemaId(AVRO_TYPE, registered_schema.schema_id, registered_schema.guid) else: registered_schema = await self._registry.lookup_schema( subject, self._schema, self._normalize_schemas) - self._schema_id = registered_schema.schema_id + self._schema_id = SchemaId(AVRO_TYPE, registered_schema.schema_id, registered_schema.guid) self._known_subjects.add(subject) @@ -320,7 +335,7 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N if latest_schema is not None: parsed_schema = await self._get_parsed_schema(latest_schema.schema) - field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 + def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, parsed_schema, msg, field_transform)) value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, latest_schema.schema, value, get_inline_tags(parsed_schema), @@ -334,7 +349,7 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N # write the record to the rest of the buffer schemaless_writer(fo, parsed_schema, value) - return fo.getvalue() + return self._schema_id_serializer(fo.getvalue(), ctx, self._schema_id) async def _get_parsed_schema(self, schema: Schema) -> AvroSchema: parsed_schema = self._parsed_schemas.get_parsed_schema(schema) @@ -378,7 +393,12 @@ class AsyncAvroDeserializer(AsyncBaseDeserializer): | | | | | | | Defaults to topic_subject_name_strategy. | +-----------------------------+----------+--------------------------------------------------+ - + | | | Callable(bytes, SerializationContext, schema_id) | + | | | -> io.BytesIO | + | | | | + | ``schema.id.deserializer`` | callable | Defines how the schema id/guid is deserialized. | + | | | Defaults to dual_schema_id_deserializer. 
| + +-----------------------------+----------+--------------------------------------------------+ Note: By default, Avro complex types are returned as dicts. This behavior can be overridden by registering a callable ``from_dict`` with the deserializer to @@ -415,7 +435,8 @@ class AsyncAvroDeserializer(AsyncBaseDeserializer): _default_conf = {'use.latest.version': False, 'use.latest.with.metadata': None, - 'subject.name.strategy': topic_subject_name_strategy} + 'subject.name.strategy': topic_subject_name_strategy, + 'schema.id.deserializer': dual_schema_id_deserializer} async def __init__( self, @@ -460,6 +481,11 @@ async def __init__( if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") + self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + if not callable(self._schema_id_serializer): + raise ValueError("schema.id.serializer must be callable") + + if len(conf_copy) > 0: raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) @@ -513,61 +539,61 @@ async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = "message was not produced with a Confluent " "Schema Registry serializer".format(len(data))) - subject = self._subject_name_func(ctx, None) + subject = self._subject_name_func(ctx, None) if ctx else None latest_schema = None if subject is not None: latest_schema = await self._get_reader_schema(subject) - with _ContextStringIO(data) as payload: - magic, schema_id = unpack('>bI', payload.read(5)) - if magic != _MAGIC_BYTE: - raise SerializationError("Unexpected magic byte {}. This message " - "was not produced with a Confluent " - "Schema Registry serializer".format(magic)) - - writer_schema_raw = await self._registry.get_schema(schema_id) - writer_schema = await self._get_parsed_schema(writer_schema_raw) - - if subject is None: - subject = self._subject_name_func(ctx, writer_schema.get("name")) - if subject is not None: - latest_schema = await self._get_reader_schema(subject) - - if latest_schema is not None: - migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) - reader_schema_raw = latest_schema.schema - reader_schema = await self._get_parsed_schema(latest_schema.schema) - elif self._schema is not None: - migrations = None - reader_schema_raw = self._schema - reader_schema = self._reader_schema - else: - migrations = None - reader_schema_raw = writer_schema_raw - reader_schema = writer_schema - - if migrations: - obj_dict = schemaless_reader(payload, - writer_schema, - None, - self._return_record_name) - obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) - else: - obj_dict = schemaless_reader(payload, - writer_schema, - reader_schema, - self._return_record_name) + schema_id = SchemaId(AVRO_TYPE) + payload = self._schema_id_deserializer(data, ctx, schema_id) + + writer_schema_raw = self._get_writer_schema(schema_id, subject) + writer_schema = self._get_parsed_schema(writer_schema_raw) + + if subject is None: + subject = self._subject_name_func(ctx, writer_schema.get("name")) if ctx else None + if subject is not None: + latest_schema = await self._get_reader_schema(subject) + + if latest_schema is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) + reader_schema_raw = latest_schema.schema + reader_schema = await self._get_parsed_schema(latest_schema.schema) + elif self._schema is not None: + migrations = None + reader_schema_raw = self._schema + reader_schema = self._reader_schema + 
else: + migrations = None + reader_schema_raw = writer_schema_raw + reader_schema = writer_schema + + if migrations: + obj_dict = schemaless_reader(payload, + writer_schema, + None, + self._return_record_name) + obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) + else: + obj_dict = schemaless_reader(payload, + writer_schema, + reader_schema, + self._return_record_name) + + + + - field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 - transform(rule_ctx, reader_schema, message, field_transform)) - obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, obj_dict, get_inline_tags(reader_schema), - field_transformer) + field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + transform(rule_ctx, reader_schema, message, field_transform)) + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, obj_dict, get_inline_tags(reader_schema), + field_transformer) - if self._from_dict is not None: - return self._from_dict(obj_dict, ctx) + if self._from_dict is not None: + return self._from_dict(obj_dict, ctx) - return obj_dict + return obj_dict async def _get_parsed_schema(self, schema: Schema) -> AvroSchema: parsed_schema = self._parsed_schemas.get_parsed_schema(schema) diff --git a/src/confluent_kafka/schema_registry/_async/json_schema.py b/src/confluent_kafka/schema_registry/_async/json_schema.py index a1d6d6b7a..f03878b42 100644 --- a/src/confluent_kafka/schema_registry/_async/json_schema.py +++ b/src/confluent_kafka/schema_registry/_async/json_schema.py @@ -28,15 +28,16 @@ from confluent_kafka.schema_registry import (_MAGIC_BYTE, Schema, topic_subject_name_strategy, - RuleMode, AsyncSchemaRegistryClient) + RuleMode, AsyncSchemaRegistryClient, + prefix_schema_id_serializer, + dual_schema_id_deserializer) from confluent_kafka.schema_registry.common import asyncinit from confluent_kafka.schema_registry.common.json_schema import ( - DEFAULT_SPEC, JsonSchema, _retrieve_via_httpx, transform + DEFAULT_SPEC, JsonSchema, _retrieve_via_httpx, transform, _ContextStringIO, JSON_TYPE ) -from confluent_kafka.schema_registry.common import _ContextStringIO from confluent_kafka.schema_registry.rule_registry import RuleRegistry from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, \ - ParsedSchemaCache + ParsedSchemaCache, SchemaId from confluent_kafka.serialization import (SerializationError, SerializationContext) @@ -46,6 +47,7 @@ 'AsyncJSONDeserializer' ] + async def _resolve_named_schema( schema: Schema, schema_registry_client: AsyncSchemaRegistryClient, ref_registry: Optional[Registry] = None @@ -70,7 +72,6 @@ async def _resolve_named_schema( ref_registry = ref_registry.with_resource(ref.name, resource) return ref_registry - @asyncinit class AsyncJSONSerializer(AsyncBaseSerializer): """ @@ -131,6 +132,12 @@ class AsyncJSONSerializer(AsyncBaseSerializer): | ``validate`` | bool | the given schema. | | | | | +-----------------------------+----------+----------------------------------------------------+ + | | | Callable(bytes, SerializationContext, schema_id) | + | | | -> bytes | + | | | | + | ``schema.id.serializer`` | callable | Defines how the schema id/guid is serialized. | + | | | Defaults to prefix_schema_id_serializer. 
| + +-----------------------------+----------+----------------------------------------------------+ Schemas are registered against subject names in Confluent Schema Registry that define a scope in which the schemas can be evolved. By default, the subject name @@ -192,6 +199,7 @@ class AsyncJSONSerializer(AsyncBaseSerializer): 'use.latest.version': False, 'use.latest.with.metadata': None, 'subject.name.strategy': topic_subject_name_strategy, + 'schema.id.serializer': prefix_schema_id_serializer, 'validate': True} async def __init__( @@ -260,6 +268,10 @@ async def __init__( if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") + self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + if not callable(self._schema_id_serializer): + raise ValueError("schema.id.serializer must be callable") + self._validate = conf_copy.pop('validate') if not isinstance(self._normalize_schemas, bool): raise ValueError("validate must be a boolean value") @@ -310,21 +322,20 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N subject = self._subject_name_func(ctx, self._schema_name) latest_schema = await self._get_reader_schema(subject) if latest_schema is not None: - self._schema_id = latest_schema.schema_id + self._schema_id = SchemaId(JSON_TYPE, latest_schema.schema_id, latest_schema.guid) elif subject not in self._known_subjects: # Check to ensure this schema has been registered under subject_name. if self._auto_register: # The schema name will always be the same. We can't however register # a schema without a subject so we set the schema_id here to handle # the initial registration. - self._schema_id = await self._registry.register_schema(subject, - self._schema, - self._normalize_schemas) + registered_schema = await self._registry.register_schema_full_response( + subject, self._schema, self._normalize_schemas) + self._schema_id = SchemaId(JSON_TYPE, registered_schema.schema_id, registered_schema.guid) else: - registered_schema = await self._registry.lookup_schema(subject, - self._schema, - self._normalize_schemas) - self._schema_id = registered_schema.schema_id + registered_schema = await self._registry.lookup_schema( + subject, self._schema, self._normalize_schemas) + self._schema_id = SchemaId(JSON_TYPE, registered_schema.schema_id, registered_schema.guid) self._known_subjects.add(subject) @@ -339,7 +350,7 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N root_resource = Resource.from_contents( parsed_schema, default_specification=DEFAULT_SPEC) ref_resolver = ref_registry.resolver_with_root(root_resource) - field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 + def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, latest_schema.schema, value, None, @@ -356,8 +367,6 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N raise SerializationError(ve.message) with _ContextStringIO() as fo: - # Write the magic byte and schema ID in network byte order (big endian) - fo.write(struct.pack(">bI", _MAGIC_BYTE, self._schema_id)) # JSON dump always writes a str never bytes # https://docs.python.org/3/library/json.html encoded_value = self._json_encode(value) @@ -365,7 +374,7 @@ async def __serialize(self, obj: object, ctx: Optional[SerializationContext] = N encoded_value = 
encoded_value.encode("utf8") fo.write(encoded_value) - return fo.getvalue() + return self._schema_id_serializer(fo.getvalue(), ctx, self._schema_id) async def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]: if schema is None: @@ -428,6 +437,12 @@ class AsyncJSONDeserializer(AsyncBaseDeserializer): | ``validate`` | bool | the given schema. | | | | | +-----------------------------+----------+----------------------------------------------------+ + | | | Callable(bytes, SerializationContext, schema_id) | + | | | -> io.BytesIO | + | | | | + | ``schema.id.deserializer`` | callable | Defines how the schema id/guid is deserialized. | + | | | Defaults to dual_schema_id_deserializer. | + +-----------------------------+----------+----------------------------------------------------+ Args: schema_str (str, Schema, optional): @@ -450,6 +465,7 @@ class AsyncJSONDeserializer(AsyncBaseDeserializer): _default_conf = {'use.latest.version': False, 'use.latest.with.metadata': None, 'subject.name.strategy': topic_subject_name_strategy, + 'schema.id.deserializer': dual_schema_id_deserializer, 'validate': True} async def __init__( @@ -504,6 +520,10 @@ async def __init__( if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") + self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') + if not callable(self._subject_name_func): + raise ValueError("schema.id.deserializer must be callable") + self._validate = conf_copy.pop('validate') if not isinstance(self._validate, bool): raise ValueError("validate must be a boolean value") @@ -551,75 +571,65 @@ async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = N if data is None: return None - if len(data) <= 5: - raise SerializationError("Expecting data framing of length 6 bytes or " - "more but total data size is {} bytes. This " - "message was not produced with a Confluent " - "Schema Registry serializer".format(len(data))) - subject = self._subject_name_func(ctx, None) latest_schema = None if subject is not None and self._registry is not None: latest_schema = await self._get_reader_schema(subject) - with _ContextStringIO(data) as payload: - magic, schema_id = struct.unpack('>bI', payload.read(5)) - if magic != _MAGIC_BYTE: - raise SerializationError("Unexpected magic byte {}. 
This message " - "was not produced with a Confluent " - "Schema Registry serializer".format(magic)) - - # JSON documents are self-describing; no need to query schema - obj_dict = self._json_decode(payload.read()) - - if self._registry is not None: - writer_schema_raw = await self._registry.get_schema(schema_id) - writer_schema, writer_ref_registry = await self._get_parsed_schema(writer_schema_raw) - if subject is None: - subject = self._subject_name_func(ctx, writer_schema.get("title")) - if subject is not None: - latest_schema = await self._get_reader_schema(subject) - else: - writer_schema_raw = None - writer_schema, writer_ref_registry = None, None - - if latest_schema is not None: - migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) - reader_schema_raw = latest_schema.schema - reader_schema, reader_ref_registry = await self._get_parsed_schema(latest_schema.schema) - elif self._schema is not None: - migrations = None - reader_schema_raw = self._schema - reader_schema, reader_ref_registry = self._reader_schema, self._ref_registry - else: - migrations = None - reader_schema_raw = writer_schema_raw - reader_schema, reader_ref_registry = writer_schema, writer_ref_registry - - if migrations: - obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) - - reader_root_resource = Resource.from_contents( - reader_schema, default_specification=DEFAULT_SPEC) - reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) - field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 - transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, - "$", message, field_transform)) - obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, obj_dict, None, - field_transformer) - - if self._validate: - try: - validator = self._get_validator(reader_schema_raw, reader_schema, reader_ref_registry) - validator.validate(obj_dict) - except ValidationError as ve: - raise SerializationError(ve.message) - - if self._from_dict is not None: - return self._from_dict(obj_dict, ctx) - - return obj_dict + schema_id = SchemaId(JSON_TYPE) + payload = self._schema_id_deserializer(data, ctx, schema_id) + + # JSON documents are self-describing; no need to query schema + obj_dict = self._json_decode(payload.read()) + + if self._registry is not None: + writer_schema_raw = await self._get_writer_schema(schema_id, subject) + writer_schema, writer_ref_registry = await self._get_parsed_schema(writer_schema_raw) + if subject is None: + subject = self._subject_name_func(ctx, writer_schema.get("title")) + if subject is not None: + latest_schema = await self._get_reader_schema(subject) + else: + writer_schema_raw = None + writer_schema, writer_ref_registry = None, None + + if latest_schema is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) + reader_schema_raw = latest_schema.schema + reader_schema, reader_ref_registry = await self._get_parsed_schema(latest_schema.schema) + elif self._schema is not None: + migrations = None + reader_schema_raw = self._schema + reader_schema, reader_ref_registry = self._reader_schema, self._ref_registry + else: + migrations = None + reader_schema_raw = writer_schema_raw + reader_schema, reader_ref_registry = writer_schema, writer_ref_registry + + if migrations: + obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) + + reader_root_resource = Resource.from_contents( + reader_schema, 
default_specification=DEFAULT_SPEC) + reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) + field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, + "$", message, field_transform)) + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, obj_dict, None, + field_transformer) + + if self._validate: + try: + validator = self._get_validator(reader_schema_raw, reader_schema, reader_ref_registry) + validator.validate(obj_dict) + except ValidationError as ve: + raise SerializationError(ve.message) + + if self._from_dict is not None: + return self._from_dict(obj_dict, ctx) + + return obj_dict async def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]: if schema is None: diff --git a/src/confluent_kafka/schema_registry/_async/protobuf.py b/src/confluent_kafka/schema_registry/_async/protobuf.py index 9fd744034..7d20cde68 100644 --- a/src/confluent_kafka/schema_registry/_async/protobuf.py +++ b/src/confluent_kafka/schema_registry/_async/protobuf.py @@ -26,19 +26,21 @@ from google.protobuf.message import DecodeError, Message from google.protobuf.message_factory import GetMessageClass -from confluent_kafka.schema_registry.common import (_MAGIC_BYTE, _ContextStringIO, - reference_subject_name_strategy, - topic_subject_name_strategy) -from confluent_kafka.schema_registry.schema_registry_client import AsyncSchemaRegistryClient -from confluent_kafka.schema_registry.common.protobuf import _bytes, _create_index_array, _init_pool, _is_builtin, _schema_to_str, _str_to_proto, transform +from confluent_kafka.schema_registry import (_MAGIC_BYTE, + reference_subject_name_strategy, + topic_subject_name_strategy, + prefix_schema_id_serializer, dual_schema_id_deserializer) +from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient +from confluent_kafka.schema_registry.common.protobuf import _bytes, _create_index_array, \ + _init_pool, _is_builtin, _schema_to_str, _str_to_proto, transform, _ContextStringIO, PROTOBUF_TYPE from confluent_kafka.schema_registry.rule_registry import RuleRegistry from confluent_kafka.schema_registry import (Schema, - SchemaReference, - RuleMode) + SchemaReference, + RuleMode) from confluent_kafka.serialization import SerializationError, \ SerializationContext from confluent_kafka.schema_registry.common import asyncinit -from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, ParsedSchemaCache +from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, ParsedSchemaCache, SchemaId __all__ = [ '_resolve_named_schema', @@ -46,6 +48,7 @@ 'AsyncProtobufDeserializer', ] + async def _resolve_named_schema( schema: Schema, schema_registry_client: AsyncSchemaRegistryClient, @@ -71,7 +74,6 @@ async def _resolve_named_schema( file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) pool.Add(file_descriptor_proto) - @asyncinit class AsyncProtobufSerializer(AsyncBaseSerializer): """ @@ -141,6 +143,12 @@ class AsyncProtobufSerializer(AsyncBaseSerializer): | | | | | | | Defaults to reference_subject_name_strategy | +-------------------------------------+----------+------------------------------------------------------+ + | | | Callable(bytes, SerializationContext, schema_id) | + | | | -> bytes | + | | | | + | ``schema.id.serializer`` | callable | Defines how the schema 
id/guid is serialized. | + | | | Defaults to prefix_schema_id_serializer. | + +-------------------------------------+----------+------------------------------------------------------+ | ``use.deprecated.format`` | bool | Specifies whether the Protobuf serializer should | | | | serialize message indexes without zig-zag encoding. | | | | This option must be explicitly configured as older | @@ -202,6 +210,7 @@ class AsyncProtobufSerializer(AsyncBaseSerializer): 'skip.known.types': True, 'subject.name.strategy': topic_subject_name_strategy, 'reference.subject.name.strategy': reference_subject_name_strategy, + 'schema.id.serializer': prefix_schema_id_serializer, 'use.deprecated.format': False, } @@ -215,13 +224,6 @@ async def __init__( ): super().__init__() - if conf is None or 'use.deprecated.format' not in conf: - raise RuntimeError( - "ProtobufSerializer: the 'use.deprecated.format' configuration " - "property must be explicitly set due to backward incompatibility " - "with older confluent-kafka-python Protobuf producers and consumers. " - "See the release notes for more details") - conf_copy = self._default_conf.copy() if conf is not None: conf_copy.update(conf) @@ -276,6 +278,10 @@ async def __init__( if not callable(self._ref_reference_subject_func): raise ValueError("subject.name.strategy must be callable") + self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + if not callable(self._schema_id_serializer): + raise ValueError("schema.id.serializer must be callable") + if len(conf_copy) > 0: raise ValueError("Unrecognized properties: {}" .format(", ".join(conf_copy.keys()))) @@ -398,14 +404,16 @@ async def __serialize(self, message: Message, ctx: Optional[SerializationContext raise ValueError("message must be of type {} not {}" .format(self._msg_class, type(message))) - subject = self._subject_name_func(ctx, - message.DESCRIPTOR.full_name) - latest_schema = await self._get_reader_schema(subject, fmt='serialized') + subject = self._subject_name_func(ctx, message.DESCRIPTOR.full_name) if ctx else None + latest_schema = None + if subject is not None: + latest_schema = await self._get_reader_schema(subject, fmt='serialized') + if latest_schema is not None: - self._schema_id = latest_schema.schema_id - elif subject not in self._known_subjects: - references = await self._resolve_dependencies( - ctx, message.DESCRIPTOR.file) + self._schema_id = SchemaId(PROTOBUF_TYPE, latest_schema.schema_id, latest_schema.guid) + + elif subject not in self._known_subjects and ctx is not None: + references = await self._resolve_dependencies(ctx, message.DESCRIPTOR.file) self._schema = Schema( self._schema.schema_str, self._schema.schema_type, @@ -413,12 +421,13 @@ async def __serialize(self, message: Message, ctx: Optional[SerializationContext ) if self._auto_register: - self._schema_id = await self._registry.register_schema(subject, - self._schema, - self._normalize_schemas) + registered_schema = await self._registry.register_schema_full_response( + subject, self._schema, self._normalize_schemas) + self._schema_id = SchemaId(PROTOBUF_TYPE, registered_schema.schema_id, registered_schema.guid) else: - self._schema_id = await self._registry.lookup_schema( - subject, self._schema, self._normalize_schemas).schema_id + registered_schema = await self._registry.lookup_schema( + subject, self._schema, self._normalize_schemas) + self._schema_id = SchemaId(PROTOBUF_TYPE, registered_schema.schema_id, registered_schema.guid) self._known_subjects.add(subject) @@ -426,23 +435,16 @@ async def __serialize(self, 
message: Message, ctx: Optional[SerializationContext fd_proto, pool = await self._get_parsed_schema(latest_schema.schema) fd = pool.FindFileByName(fd_proto.name) desc = fd.message_types_by_name[message.DESCRIPTOR.name] - field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 + def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 transform(rule_ctx, desc, msg, field_transform)) message = self._execute_rules(ctx, subject, RuleMode.WRITE, None, latest_schema.schema, message, None, field_transformer) with _ContextStringIO() as fo: - # Write the magic byte and schema ID in network byte order - # (big endian) - fo.write(struct.pack('>bI', _MAGIC_BYTE, self._schema_id)) - # write the index array that specifies the message descriptor - # of the serialized data. - self._encode_varints(fo, self._index_array, - zigzag=not self._use_deprecated_format) - # write the serialized data itself fo.write(message.SerializeToString()) - return fo.getvalue() + self._schema_id.message_indexes = self._index_array + return self._schema_id_serializer(fo.getvalue(), ctx, self._schema_id) async def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]: result = self._parsed_schemas.get_parsed_schema(schema) @@ -457,7 +459,6 @@ async def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileD self._parsed_schemas.set(schema, (fd_proto, pool)) return fd_proto, pool - @asyncinit class AsyncProtobufDeserializer(AsyncBaseDeserializer): """ @@ -491,6 +492,12 @@ class AsyncProtobufDeserializer(AsyncBaseDeserializer): | | | | | | | Defaults to topic_subject_name_strategy. | +-------------------------------------+----------+------------------------------------------------------+ + | | | Callable(bytes, SerializationContext, schema_id) | + | | | -> io.BytesIO | + | | | | + | ``schema.id.deserializer`` | callable | Defines how the schema id/guid is deserialized. | + | | | Defaults to dual_schema_id_deserializer. | + +-------------------------------------+----------+------------------------------------------------------+ | ``use.deprecated.format`` | bool | Specifies whether the Protobuf deserializer should | | | | deserialize message indexes without zig-zag encoding.| | | | This option must be explicitly configured as older | @@ -514,6 +521,7 @@ class AsyncProtobufDeserializer(AsyncBaseDeserializer): 'use.latest.version': False, 'use.latest.with.metadata': None, 'subject.name.strategy': topic_subject_name_strategy, + 'schema.id.deserializer': dual_schema_id_deserializer, 'use.deprecated.format': False, } @@ -532,16 +540,6 @@ async def __init__( self._parsed_schemas = ParsedSchemaCache() self._use_schema_id = None - # Require use.deprecated.format to be explicitly configured - # during a transitionary period since old/new format are - # incompatible. - if conf is None or 'use.deprecated.format' not in conf: - raise RuntimeError( - "ProtobufDeserializer: the 'use.deprecated.format' configuration " - "property must be explicitly set due to backward incompatibility " - "with older confluent-kafka-python Protobuf producers and consumers. 
" - "See the release notes for more details") - conf_copy = self._default_conf.copy() if conf is not None: conf_copy.update(conf) @@ -559,6 +557,10 @@ async def __init__( if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") + self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') + if not callable(self._schema_id_deserializer): + raise ValueError("schema.id.deserializer must be callable") + self._use_deprecated_format = conf_copy.pop('use.deprecated.format') if not isinstance(self._use_deprecated_format, bool): raise ValueError("use.deprecated.format must be a boolean value") @@ -579,85 +581,6 @@ async def __init__( rule.configure(self._registry.config() if self._registry else {}, rule_conf if rule_conf else {}) - @staticmethod - def _decode_varint(buf: io.BytesIO, zigzag: bool = True) -> int: - """ - Decodes a single varint from a buffer. - - Args: - buf (BytesIO): buffer to read from - zigzag (bool): decode as zigzag or uvarint - - Returns: - int: decoded varint - - Raises: - EOFError: if buffer is empty - """ - - value = 0 - shift = 0 - try: - while True: - i = AsyncProtobufDeserializer._read_byte(buf) - - value |= (i & 0x7f) << shift - shift += 7 - if not (i & 0x80): - break - - if zigzag: - value = (value >> 1) ^ -(value & 1) - - return value - - except EOFError: - raise EOFError("Unexpected EOF while reading index") - - @staticmethod - def _read_byte(buf: io.BytesIO) -> int: - """ - Read one byte from buf as an int. - - Args: - buf (BytesIO): The buffer to read from. - - .. _ord: - https://docs.python.org/2/library/functions.html#ord - """ - - i = buf.read(1) - if i == b'': - raise EOFError("Unexpected EOF encountered") - return ord(i) - - @staticmethod - def _read_index_array(buf: io.BytesIO, zigzag: bool = True) -> List[int]: - """ - Read an index array from buf that specifies the message - descriptor of interest in the file descriptor. - - Args: - buf (BytesIO): The buffer to read from. - - Returns: - list of int: The index array. - """ - - size = AsyncProtobufDeserializer._decode_varint(buf, zigzag=zigzag) - if size < 0 or size > 100000: - raise DecodeError("Invalid Protobuf msgidx array length") - - if size == 0: - return [0] - - msg_index = [] - for _ in range(size): - msg_index.append(AsyncProtobufDeserializer._decode_varint(buf, - zigzag=zigzag)) - - return msg_index - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: return self.__serialize(data, ctx) @@ -684,82 +607,69 @@ async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = N if data is None: return None - # SR wire protocol + msg_index length - if len(data) < 6: - raise SerializationError("Expecting data framing of length 6 bytes or " - "more but total data size is {} bytes. This " - "message was not produced with a Confluent " - "Schema Registry serializer".format(len(data))) - subject = self._subject_name_func(ctx, None) latest_schema = None if subject is not None and self._registry is not None: latest_schema = await self._get_reader_schema(subject, fmt='serialized') - with _ContextStringIO(data) as payload: - magic, schema_id = struct.unpack('>bI', payload.read(5)) - if magic != _MAGIC_BYTE: - raise SerializationError("Unknown magic byte. 
This message was " - "not produced with a Confluent " - "Schema Registry serializer") - - msg_index = self._read_index_array(payload, zigzag=not self._use_deprecated_format) - - if self._registry is not None: - writer_schema_raw = await self._registry.get_schema(schema_id, fmt='serialized') - fd_proto, pool = await self._get_parsed_schema(writer_schema_raw) - writer_schema = pool.FindFileByName(fd_proto.name) - writer_desc = self._get_message_desc(pool, writer_schema, msg_index) - if subject is None: - subject = self._subject_name_func(ctx, writer_desc.full_name) - if subject is not None: - latest_schema = await self._get_reader_schema(subject, fmt='serialized') - else: - writer_schema_raw = None - writer_schema = None - - if latest_schema is not None: - migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) - reader_schema_raw = latest_schema.schema - fd_proto, pool = await self._get_parsed_schema(latest_schema.schema) - reader_schema = pool.FindFileByName(fd_proto.name) - else: - migrations = None - reader_schema_raw = writer_schema_raw - reader_schema = writer_schema - - if reader_schema is not None: - # Initialize reader desc to first message in file - reader_desc = self._get_message_desc(pool, reader_schema, [0]) - # Attempt to find a reader desc with the same name as the writer - reader_desc = reader_schema.message_types_by_name.get(writer_desc.name, reader_desc) - - if migrations: - msg = GetMessageClass(writer_desc)() - try: - msg.ParseFromString(payload.read()) - except DecodeError as e: - raise SerializationError(str(e)) - - obj_dict = json_format.MessageToDict(msg, True) - obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) - msg = GetMessageClass(reader_desc)() - msg = json_format.ParseDict(obj_dict, msg) - else: - # Protobuf Messages are self-describing; no need to query schema - msg = self._msg_class() - try: - msg.ParseFromString(payload.read()) - except DecodeError as e: - raise SerializationError(str(e)) - - field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 - transform(rule_ctx, reader_desc, message, field_transform)) - msg = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, msg, None, - field_transformer) - - return msg + schema_id = SchemaId(PROTOBUF_TYPE) + payload = self._schema_id_deserializer(data, ctx, schema_id) + msg_index = schema_id.message_indexes + + if self._registry is not None: + writer_schema_raw = await self._get_writer_schema(schema_id, subject, fmt='serialized') + fd_proto, pool = await self._get_parsed_schema(writer_schema_raw) + writer_schema = pool.FindFileByName(fd_proto.name) + writer_desc = self._get_message_desc(pool, writer_schema, msg_index) + if subject is None: + subject = self._subject_name_func(ctx, writer_desc.full_name) + if subject is not None: + latest_schema = self._get_reader_schema(subject, fmt='serialized') + else: + writer_schema_raw = None + writer_schema = None + + if latest_schema is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) + reader_schema_raw = latest_schema.schema + fd_proto, pool = await self._get_parsed_schema(latest_schema.schema) + reader_schema = pool.FindFileByName(fd_proto.name) + else: + migrations = None + reader_schema_raw = writer_schema_raw + reader_schema = writer_schema + + if reader_schema is not None: + # Initialize reader desc to first message in file + reader_desc = self._get_message_desc(pool, reader_schema, [0]) + # Attempt to find a reader desc with the 
same name as the writer + reader_desc = reader_schema.message_types_by_name.get(writer_desc.name, reader_desc) + + if migrations: + msg = GetMessageClass(writer_desc)() + try: + msg.ParseFromString(payload.read()) + except DecodeError as e: + raise SerializationError(str(e)) + + obj_dict = json_format.MessageToDict(msg, True) + obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) + msg = GetMessageClass(reader_desc)() + msg = json_format.ParseDict(obj_dict, msg) + else: + # Protobuf Messages are self-describing; no need to query schema + msg = self._msg_class() + try: + msg.ParseFromString(payload.read()) + except DecodeError as e: + raise SerializationError(str(e)) + + field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + transform(rule_ctx, reader_desc, message, field_transform)) + msg = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, msg, None, + field_transformer) + return msg async def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]: result = self._parsed_schemas.get_parsed_schema(schema) diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index 29b83e2ce..e5f929beb 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -32,10 +32,10 @@ from confluent_kafka.schema_registry.error import SchemaRegistryError, OAuthTokenError from confluent_kafka.schema_registry.common.schema_registry_client import ( - RegisteredSchema, - ServerConfig, - is_success, - is_retriable, + RegisteredSchema, + ServerConfig, + is_success, + is_retriable, _BearerFieldProvider, full_jitter, _SchemaCache, @@ -98,7 +98,7 @@ def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoin async def get_bearer_fields(self) -> dict: return { - 'bearer.auth.token': await self.get_access_token(), + 'bearer.auth.token': await self.get_access_token(), 'bearer.auth.logical.cluster': self.logical_cluster, 'bearer.auth.identity.pool.id': self.identity_pool } @@ -560,9 +560,9 @@ class AsyncSchemaRegistryClient(object): `Confluent Schema Registry documentation `_ """ # noqa: E501 - def __init__(self, conf: dict, rest_client: _AsyncRestClient = None): + def __init__(self, conf: dict): self._conf = conf - self._rest_client = rest_client or _AsyncRestClient(conf) + self._rest_client = _AsyncRestClient(conf) self._cache = _SchemaCache() cache_capacity = self._rest_client.cache_capacity cache_ttl = self._rest_client.cache_latest_ttl_sec @@ -634,7 +634,9 @@ async def register_schema_full_response( schema_id = self._cache.get_id_by_schema(subject_name, schema) if schema_id is not None: - return RegisteredSchema(schema_id, schema, subject_name, None) + result = self._cache.get_schema_by_id(subject_name, schema_id) + if result is not None: + return RegisteredSchema(schema_id, result[0], result[1], subject_name, None) request = schema.to_dict() @@ -645,7 +647,9 @@ async def register_schema_full_response( registered_schema = RegisteredSchema.from_dict(response) # The registered schema may not be fully populated - self._cache.set_schema(subject_name, registered_schema.schema_id, schema) + s = registered_schema.schema if registered_schema.schema.schema_str is not None else schema + self._cache.set_schema(subject_name, registered_schema.schema_id, + registered_schema.guid, s) return registered_schema 
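Note on the registration path above: register_schema_full_response now caches the complete RegisteredSchema, so both the numeric schema id and the new guid are available to callers and to later cache hits. A minimal sketch of the intended flow follows; the registry URL, subject name, and toy schema are illustrative placeholders, not values taken from this patch.

    import asyncio

    from confluent_kafka.schema_registry import AsyncSchemaRegistryClient, Schema


    async def register_and_fetch():
        # Hypothetical connection settings; point this at a real Schema Registry.
        client = AsyncSchemaRegistryClient({'url': 'http://localhost:8081'})

        schema = Schema('{"type": "string"}', 'AVRO')

        # The full response carries both schema_id and guid, and the client caches
        # them, so the follow-up fetch below should not need another round-trip.
        registered = await client.register_schema_full_response('example-value', schema)
        print(registered.schema_id, registered.guid)

        fetched = await client.get_schema(registered.schema_id, subject_name='example-value')
        print(fetched.schema_str)


    asyncio.run(register_and_fetch())
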
@@ -672,9 +676,9 @@ async def get_schema( `GET Schema API Reference `_ """ # noqa: E501 - schema = self._cache.get_schema_by_id(subject_name, schema_id) - if schema is not None: - return schema + result = self._cache.get_schema_by_id(subject_name, schema_id) + if result is not None: + return result[1] query = {'subject': subject_name} if subject_name is not None else None if fmt is not None: @@ -684,11 +688,49 @@ async def get_schema( query = {'format': fmt} response = await self._rest_client.get('schemas/ids/{}'.format(schema_id), query) - schema = Schema.from_dict(response) + registered_schema = RegisteredSchema.from_dict(response) + + self._cache.set_schema(subject_name, schema_id, + registered_schema.guid, registered_schema.schema) + + return registered_schema.schema + + async def get_schema_by_guid( + self, guid: str, fmt: Optional[str] = None + ) -> 'Schema': + """ + Fetches the schema associated with ``guid`` from the + Schema Registry. The result is cached so subsequent attempts will not + require an additional round-trip to the Schema Registry. + + Args: + guid (str): Schema guid + fmt (str): Format of the schema + + Returns: + Schema: Schema instance identified by the ``guid`` + + Raises: + SchemaRegistryError: If schema can't be found. + + See Also: + `GET Schema API Reference `_ + """ # noqa: E501 + + schema = self._cache.get_schema_by_guid(guid) + if schema is not None: + return schema + + if fmt is not None: + query = {'format': fmt} + response = await self._rest_client.get('schemas/guids/{}'.format(guid), query) + + registered_schema = RegisteredSchema.from_dict(response) - self._cache.set_schema(subject_name, schema_id, schema) + self._cache.set_schema(None, registered_schema.schema_id, + registered_schema.guid, registered_schema.schema) - return schema + return registered_schema.schema async def lookup_schema( self, subject_name: str, schema: 'Schema', @@ -728,6 +770,7 @@ async def lookup_schema( # Ensure the schema matches the input registered_schema = RegisteredSchema( schema_id=result.schema_id, + guid=result.guid, subject=result.subject, version=result.version, schema=schema, diff --git a/src/confluent_kafka/schema_registry/_async/serde.py b/src/confluent_kafka/schema_registry/_async/serde.py index f72d108c6..7824740f6 100644 --- a/src/confluent_kafka/schema_registry/_async/serde.py +++ b/src/confluent_kafka/schema_registry/_async/serde.py @@ -20,7 +20,9 @@ from typing import List, Optional, Set, Dict, Any from confluent_kafka.schema_registry import RegisteredSchema -from confluent_kafka.schema_registry.common.serde import ErrorAction, FieldTransformer, Migration, NoneAction, RuleAction, RuleConditionError, RuleContext, RuleError +from confluent_kafka.schema_registry.common.serde import ErrorAction, \ + FieldTransformer, Migration, NoneAction, RuleAction, \ + RuleConditionError, RuleContext, RuleError from confluent_kafka.schema_registry.schema_registry_client import RuleMode, \ Rule, RuleKind, Schema, RuleSet from confluent_kafka.serialization import Serializer, Deserializer, \ @@ -34,6 +36,7 @@ log = logging.getLogger(__name__) + class AsyncBaseSerde(object): __slots__ = ['_use_schema_id', '_use_latest_version', '_use_latest_with_metadata', '_registry', '_rule_registry', '_subject_name_func', diff --git a/tests/integration/schema_registry/_sync/test_proto_serializers.py b/tests/integration/schema_registry/_sync/test_proto_serializers.py index 7ea741856..54e458152 100644 --- a/tests/integration/schema_registry/_sync/test_proto_serializers.py +++ 
b/tests/integration/schema_registry/_sync/test_proto_serializers.py @@ -92,7 +92,7 @@ def test_protobuf_reference_registration(kafka_cluster, pb2, expected_refs): producer.produce(topic, key=pb2(), partition=0) producer.flush() - registered_refs = sr.get_schema(serializer._schema_id.id).references + registered_refs = (sr.get_schema(serializer._schema_id)).references assert expected_refs.sort() == [ref.name for ref in registered_refs].sort() From 12b0ec3ec0978e230335a578c43144beca507209 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Tue, 27 May 2025 21:01:53 -0700 Subject: [PATCH 19/32] fix --- src/confluent_kafka/schema_registry/_async/protobuf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/confluent_kafka/schema_registry/_async/protobuf.py b/src/confluent_kafka/schema_registry/_async/protobuf.py index 7d20cde68..f072c527e 100644 --- a/src/confluent_kafka/schema_registry/_async/protobuf.py +++ b/src/confluent_kafka/schema_registry/_async/protobuf.py @@ -30,7 +30,7 @@ reference_subject_name_strategy, topic_subject_name_strategy, prefix_schema_id_serializer, dual_schema_id_deserializer) -from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient +from confluent_kafka.schema_registry.schema_registry_client import AsyncSchemaRegistryClient from confluent_kafka.schema_registry.common.protobuf import _bytes, _create_index_array, \ _init_pool, _is_builtin, _schema_to_str, _str_to_proto, transform, _ContextStringIO, PROTOBUF_TYPE from confluent_kafka.schema_registry.rule_registry import RuleRegistry From f9a82b49e38ea31563eb109fe5cfc438353874ae Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Wed, 28 May 2025 12:45:25 -0700 Subject: [PATCH 20/32] add async changes and refactor into separate unasync script --- setup.py | 132 +--------------- .../schema_registry/_async/avro.py | 29 ++-- .../schema_registry/_async/json_schema.py | 7 +- .../schema_registry/_async/protobuf.py | 6 +- .../schema_registry/_async/serde.py | 19 ++- tools/unasync.py | 145 ++++++++++++++++++ 6 files changed, 176 insertions(+), 162 deletions(-) create mode 100644 tools/unasync.py diff --git a/setup.py b/setup.py index 91b0fb9f2..1141e7cee 100755 --- a/setup.py +++ b/setup.py @@ -1,13 +1,9 @@ #!/usr/bin/env python import os -import platform -import re -import sys from setuptools import setup from setuptools import Extension -from setuptools.command.build_py import build_py as _build_py -from pprint import pprint +import platform work_dir = os.path.dirname(os.path.realpath(__file__)) mod_dir = os.path.join(work_dir, 'src', 'confluent_kafka') @@ -29,128 +25,4 @@ os.path.join(ext_dir, 'AdminTypes.c'), os.path.join(ext_dir, 'Admin.c')]) -SUBS = [ - ('from confluent_kafka.schema_registry.common import asyncinit', ''), - ('@asyncinit', ''), - ('import asyncio', ''), - ('asyncio.sleep', 'time.sleep'), - - ('Async([A-Z][A-Za-z0-9_]*)', r'\2'), - ('_Async([A-Z][A-Za-z0-9_]*)', r'_\2'), - ('async_([a-z][A-Za-z0-9_]*)', r'\2'), - - ('async def', 'def'), - ('await ', ''), - ('aclose', 'close'), - ('__aenter__', '__enter__'), - ('__aexit__', '__exit__'), - ('__aiter__', '__iter__'), -] -COMPILED_SUBS = [ - (re.compile(r'(^|\b)' + regex + r'($|\b)'), repl) - for regex, repl in SUBS -] - -USED_SUBS = set() - -def unasync_line(line): - for index, (regex, repl) in enumerate(COMPILED_SUBS): - old_line = line - line = re.sub(regex, repl, line) - if old_line != line: - USED_SUBS.add(index) - return line - - -def unasync_file(in_path, out_path): - with open(in_path, "r") as in_file: - with 
open(out_path, "w", newline="") as out_file: - for line in in_file.readlines(): - line = unasync_line(line) - out_file.write(line) - - -def unasync_file_check(in_path, out_path): - with open(in_path, "r") as in_file: - with open(out_path, "r") as out_file: - for in_line, out_line in zip(in_file.readlines(), out_file.readlines()): - expected = unasync_line(in_line) - if out_line != expected: - print(f'unasync mismatch between {in_path!r} and {out_path!r}') - print(f'Async code: {in_line!r}') - print(f'Expected sync code: {expected!r}') - print(f'Actual sync code: {out_line!r}') - sys.exit(1) - - -def unasync_dir(in_dir, out_dir, check_only=False): - for dirpath, dirnames, filenames in os.walk(in_dir): - for filename in filenames: - if not filename.endswith('.py'): - continue - rel_dir = os.path.relpath(dirpath, in_dir) - in_path = os.path.normpath(os.path.join(in_dir, rel_dir, filename)) - out_path = os.path.normpath(os.path.join(out_dir, rel_dir, filename)) - print(in_path, '->', out_path) - if check_only: - unasync_file_check(in_path, out_path) - else: - unasync_file(in_path, out_path) - -def unasync(): - unasync_dir( - "src/confluent_kafka/schema_registry/_async", - "src/confluent_kafka/schema_registry/_sync", - check_only=False - ) - unasync_dir( - "tests/integration/schema_registry/_async", - "tests/integration/schema_registry/_sync", - check_only=False - ) - - - if len(USED_SUBS) != len(SUBS): - unused_subs = [SUBS[i] for i in range(len(SUBS)) if i not in USED_SUBS] - - print("These patterns were not used:") - pprint(unused_subs) - - -class build_py(_build_py): - """ - Subclass build_py from setuptools to modify its behavior. - - Convert files in _async dir from being asynchronous to synchronous - and saves them to the specified output directory. - """ - - def run(self): - self._updated_files = [] - - # Base class code - if self.py_modules: - self.build_modules() - if self.packages: - self.build_packages() - self.build_package_data() - - # Our modification - unasync() - - # Remaining base class code - self.byte_compile(self.get_outputs(include_bytecode=0)) - - def build_module(self, module, module_file, package): - outfile, copied = super().build_module(module, module_file, package) - if copied: - self._updated_files.append(outfile) - return outfile, copied - - -setup( - ext_modules=[module], - cmdclass={ - 'build_py': build_py, - } -) +setup(ext_modules=[module]) diff --git a/src/confluent_kafka/schema_registry/_async/avro.py b/src/confluent_kafka/schema_registry/_async/avro.py index c5f0d1dc3..42d3c8fca 100644 --- a/src/confluent_kafka/schema_registry/_async/avro.py +++ b/src/confluent_kafka/schema_registry/_async/avro.py @@ -16,7 +16,6 @@ # limitations under the License. 
from json import loads -from struct import pack, unpack from typing import Dict, Union, Optional, Callable from fastavro import schemaless_reader, schemaless_writer @@ -24,8 +23,7 @@ from confluent_kafka.schema_registry.common.avro import AvroSchema, _schema_loads, \ get_inline_tags, parse_schema_with_repo, transform, _ContextStringIO, AVRO_TYPE -from confluent_kafka.schema_registry import (_MAGIC_BYTE, - Schema, +from confluent_kafka.schema_registry import (Schema, topic_subject_name_strategy, RuleMode, AsyncSchemaRegistryClient, @@ -244,10 +242,10 @@ async def __init__( self._subject_name_func = conf_copy.pop('subject.name.strategy') if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - - self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') - if not callable(self._schema_id_deserializer): - raise ValueError("schema.id.deserializer must be callable") + + self._schema_id_serializer = conf_copy.pop('schema.id.serializer') + if not callable(self._schema_id_serializer): + raise ValueError("schema.id.serializer must be callable") if len(conf_copy) > 0: raise ValueError("Unrecognized properties: {}" @@ -344,8 +342,6 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 parsed_schema = self._parsed_schema with _ContextStringIO() as fo: - # Write the magic byte and schema ID in network byte order (big endian) - fo.write(pack('>bI', _MAGIC_BYTE, self._schema_id)) # write the record to the rest of the buffer schemaless_writer(fo, parsed_schema, value) @@ -481,10 +477,9 @@ async def __init__( if not callable(self._subject_name_func): raise ValueError("subject.name.strategy must be callable") - self._schema_id_serializer = conf_copy.pop('schema.id.serializer') - if not callable(self._schema_id_serializer): - raise ValueError("schema.id.serializer must be callable") - + self._schema_id_deserializer = conf_copy.pop('schema.id.deserializer') + if not callable(self._schema_id_deserializer): + raise ValueError("schema.id.deserializer must be callable") if len(conf_copy) > 0: raise ValueError("Unrecognized properties: {}" @@ -547,7 +542,7 @@ async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = schema_id = SchemaId(AVRO_TYPE) payload = self._schema_id_deserializer(data, ctx, schema_id) - writer_schema_raw = self._get_writer_schema(schema_id, subject) + writer_schema_raw = await self._get_writer_schema(schema_id, subject) writer_schema = self._get_parsed_schema(writer_schema_raw) if subject is None: @@ -580,11 +575,7 @@ async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = reader_schema, self._return_record_name) - - - - - field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_schema, message, field_transform)) obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, reader_schema_raw, obj_dict, get_inline_tags(reader_schema), diff --git a/src/confluent_kafka/schema_registry/_async/json_schema.py b/src/confluent_kafka/schema_registry/_async/json_schema.py index f03878b42..89968ae42 100644 --- a/src/confluent_kafka/schema_registry/_async/json_schema.py +++ b/src/confluent_kafka/schema_registry/_async/json_schema.py @@ -16,7 +16,6 @@ # limitations under the License. 
import json -import struct from typing import Union, Optional, Tuple, Callable from cachetools import LRUCache @@ -25,8 +24,7 @@ from jsonschema.validators import validator_for from referencing import Registry, Resource -from confluent_kafka.schema_registry import (_MAGIC_BYTE, - Schema, +from confluent_kafka.schema_registry import (Schema, topic_subject_name_strategy, RuleMode, AsyncSchemaRegistryClient, prefix_schema_id_serializer, @@ -612,7 +610,8 @@ async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = N reader_root_resource = Resource.from_contents( reader_schema, default_specification=DEFAULT_SPEC) reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) - field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + + def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, "$", message, field_transform)) obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, diff --git a/src/confluent_kafka/schema_registry/_async/protobuf.py b/src/confluent_kafka/schema_registry/_async/protobuf.py index f072c527e..f6950eebf 100644 --- a/src/confluent_kafka/schema_registry/_async/protobuf.py +++ b/src/confluent_kafka/schema_registry/_async/protobuf.py @@ -16,7 +16,6 @@ # limitations under the License. import io -import struct import warnings from typing import Set, List, Union, Optional, Tuple @@ -26,8 +25,7 @@ from google.protobuf.message import DecodeError, Message from google.protobuf.message_factory import GetMessageClass -from confluent_kafka.schema_registry import (_MAGIC_BYTE, - reference_subject_name_strategy, +from confluent_kafka.schema_registry import (reference_subject_name_strategy, topic_subject_name_strategy, prefix_schema_id_serializer, dual_schema_id_deserializer) from confluent_kafka.schema_registry.schema_registry_client import AsyncSchemaRegistryClient @@ -664,7 +662,7 @@ async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = N except DecodeError as e: raise SerializationError(str(e)) - field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 + def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 transform(rule_ctx, reader_desc, message, field_transform)) msg = self._execute_rules(ctx, subject, RuleMode.READ, None, reader_schema_raw, msg, None, diff --git a/src/confluent_kafka/schema_registry/_async/serde.py b/src/confluent_kafka/schema_registry/_async/serde.py index 7824740f6..0ff1ec3bc 100644 --- a/src/confluent_kafka/schema_registry/_async/serde.py +++ b/src/confluent_kafka/schema_registry/_async/serde.py @@ -22,7 +22,7 @@ from confluent_kafka.schema_registry import RegisteredSchema from confluent_kafka.schema_registry.common.serde import ErrorAction, \ FieldTransformer, Migration, NoneAction, RuleAction, \ - RuleConditionError, RuleContext, RuleError + RuleConditionError, RuleContext, RuleError, SchemaId from confluent_kafka.schema_registry.schema_registry_client import RuleMode, \ Rule, RuleKind, Schema, RuleSet from confluent_kafka.serialization import Serializer, Deserializer, \ @@ -180,12 +180,21 @@ def _get_rule_action(self, ctx: RuleContext, action_name: str) -> Optional[RuleA return self._rule_registry.get_action(action_name) -class AsyncBaseSerializer(AsyncBaseSerde, Serializer): - __slots__ = ['_auto_register', '_normalize_schemas'] +class BaseSerializer(AsyncBaseSerde, Serializer): + __slots__ = ['_auto_register', 
'_normalize_schemas', '_schema_id_serializer'] -class AsyncBaseDeserializer(AsyncBaseSerde, Deserializer): - __slots__ = [] +class BaseDeserializer(AsyncBaseSerde, Deserializer): + __slots__ = ['_schema_id_deserializer'] + + async def _get_writer_schema(self, schema_id: SchemaId, subject: Optional[str] = None, + fmt: Optional[str] = None) -> Schema: + if schema_id.id is not None: + return await self._registry.get_schema(schema_id.id, subject, fmt) + elif schema_id.guid is not None: + return await self._registry.get_schema_by_guid(str(schema_id.guid), fmt) + else: + raise SerializationError("Schema ID or GUID is not set") def _has_rules(self, rule_set: RuleSet, mode: RuleMode) -> bool: if rule_set is None: diff --git a/tools/unasync.py b/tools/unasync.py new file mode 100644 index 000000000..d922bf80d --- /dev/null +++ b/tools/unasync.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python + +import os +import re +import sys +import argparse +from pprint import pprint +import subprocess + +SUBS = [ + ('from confluent_kafka.schema_registry.common import asyncinit', ''), + ('@asyncinit', ''), + ('import asyncio', ''), + ('asyncio.sleep', 'time.sleep'), + + ('Async([A-Z][A-Za-z0-9_]*)', r'\2'), + ('_Async([A-Z][A-Za-z0-9_]*)', r'_\2'), + ('async_([a-z][A-Za-z0-9_]*)', r'\2'), + + ('async def', 'def'), + ('await ', ''), + ('aclose', 'close'), + ('__aenter__', '__enter__'), + ('__aexit__', '__exit__'), + ('__aiter__', '__iter__'), +] + +COMPILED_SUBS = [ + (re.compile(r'(^|\b)' + regex + r'($|\b)'), repl) + for regex, repl in SUBS +] + +USED_SUBS = set() + +def unasync_line(line): + for index, (regex, repl) in enumerate(COMPILED_SUBS): + old_line = line + line = re.sub(regex, repl, line) + if old_line != line: + USED_SUBS.add(index) + return line + + +def unasync_file(in_path, out_path): + with open(in_path, "r") as in_file: + with open(out_path, "w", newline="") as out_file: + for line in in_file.readlines(): + line = unasync_line(line) + out_file.write(line) + + +def unasync_file_check(in_path, out_path): + with open(in_path, "r") as in_file: + with open(out_path, "r") as out_file: + for in_line, out_line in zip(in_file.readlines(), out_file.readlines()): + expected = unasync_line(in_line) + if out_line != expected: + print(f'unasync mismatch between {in_path!r} and {out_path!r}') + print(f'Async code: {in_line!r}') + print(f'Expected sync code: {expected!r}') + print(f'Actual sync code: {out_line!r}') + sys.exit(1) + + +def unasync_dir(in_dir, out_dir, check_only=False): + for dirpath, dirnames, filenames in os.walk(in_dir): + for filename in filenames: + if not filename.endswith('.py'): + continue + rel_dir = os.path.relpath(dirpath, in_dir) + in_path = os.path.normpath(os.path.join(in_dir, rel_dir, filename)) + out_path = os.path.normpath(os.path.join(out_dir, rel_dir, filename)) + print(in_path, '->', out_path) + if check_only: + unasync_file_check(in_path, out_path) + else: + unasync_file(in_path, out_path) + +def check_diff(sync_dir): + """Check if there are any differences in the sync directory. 
+ Returns a list of files that have differences.""" + try: + # Get the list of files in the sync directory + result = subprocess.run(['git', 'ls-files', sync_dir], + capture_output=True, text=True) + if result.returncode != 0: + print(f"Error listing files in {sync_dir}") + return [] + + files = result.stdout.strip().split('\n') + if not files or (len(files) == 1 and not files[0]): + print(f"No files found in {sync_dir}") + return [] + + # Check if any of these files have differences + files_with_diff = [] + for file in files: + if not file: # Skip empty lines + continue + diff_result = subprocess.run(['git', 'diff', '--quiet', file], + capture_output=True, text=True) + if diff_result.returncode != 0: + files_with_diff.append(file) + return files_with_diff + except subprocess.CalledProcessError as e: + print(f"Error checking differences: {e}") + return [] + +def unasync(check=False): + async_dirs = [ + "src/confluent_kafka/schema_registry/_async", + "tests/integration/schema_registry/_async" + ] + sync_dirs = [ + "src/confluent_kafka/schema_registry/_sync", + "tests/integration/schema_registry/_sync" + ] + + print("Converting async code to sync code...") + for async_dir, sync_dir in zip(async_dirs, sync_dirs): + unasync_dir(async_dir, sync_dir, check_only=False) + + files_with_diff = [] + if check: + for sync_dir in sync_dirs: + files_with_diff.extend(check_diff(sync_dir)) + + if files_with_diff: + print("\n⚠️ Detected changes to a _sync directory that are uncommitted.") + print("\nFiles with differences:") + for file in files_with_diff: + print(f" - {file}") + print("\nPlease either:") + print("1. Commit the changes in the generated _sync files, or") + print("2. Revert the changes in the original _async files") + sys.exit(1) + else: + print("\n✅ Conversion completed successfully!") + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Convert async code to sync code') + parser.add_argument('--check', action='store_true', + help='Exit with non-zero status if sync directory has any differences') + args = parser.parse_args() + unasync(check=args.check) From 11777862248e709185f309bd9e7f72be79f11383 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Wed, 28 May 2025 12:47:09 -0700 Subject: [PATCH 21/32] add check to source-package-verification --- tools/source-package-verification.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tools/source-package-verification.sh b/tools/source-package-verification.sh index a84e20c5a..19ae2243b 100755 --- a/tools/source-package-verification.sh +++ b/tools/source-package-verification.sh @@ -27,6 +27,9 @@ if [[ $RUN_COVERAGE == true ]]; then exit 0 fi +echo "Checking for uncommitted changes in generated _sync directories" +python3 tools/unasync.py --check + python3 -m pip install . 
if [[ $OS_NAME == linux && $ARCH == x64 ]]; then From b3ed4593538f90c7134d8d31428421f9f4e5e7fd Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Wed, 28 May 2025 14:13:12 -0700 Subject: [PATCH 22/32] fixes --- src/confluent_kafka/schema_registry/_async/avro.py | 2 +- src/confluent_kafka/schema_registry/_async/serde.py | 4 ++-- .../schema_registry/_async/test_proto_serializers.py | 2 +- .../schema_registry/_sync/test_proto_serializers.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_async/avro.py b/src/confluent_kafka/schema_registry/_async/avro.py index 42d3c8fca..b92b9a600 100644 --- a/src/confluent_kafka/schema_registry/_async/avro.py +++ b/src/confluent_kafka/schema_registry/_async/avro.py @@ -543,7 +543,7 @@ async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = payload = self._schema_id_deserializer(data, ctx, schema_id) writer_schema_raw = await self._get_writer_schema(schema_id, subject) - writer_schema = self._get_parsed_schema(writer_schema_raw) + writer_schema = await self._get_parsed_schema(writer_schema_raw) if subject is None: subject = self._subject_name_func(ctx, writer_schema.get("name")) if ctx else None diff --git a/src/confluent_kafka/schema_registry/_async/serde.py b/src/confluent_kafka/schema_registry/_async/serde.py index 0ff1ec3bc..c7e8fe461 100644 --- a/src/confluent_kafka/schema_registry/_async/serde.py +++ b/src/confluent_kafka/schema_registry/_async/serde.py @@ -180,11 +180,11 @@ def _get_rule_action(self, ctx: RuleContext, action_name: str) -> Optional[RuleA return self._rule_registry.get_action(action_name) -class BaseSerializer(AsyncBaseSerde, Serializer): +class AsyncBaseSerializer(AsyncBaseSerde, Serializer): __slots__ = ['_auto_register', '_normalize_schemas', '_schema_id_serializer'] -class BaseDeserializer(AsyncBaseSerde, Deserializer): +class AsyncBaseDeserializer(AsyncBaseSerde, Deserializer): __slots__ = ['_schema_id_deserializer'] async def _get_writer_schema(self, schema_id: SchemaId, subject: Optional[str] = None, diff --git a/tests/integration/schema_registry/_async/test_proto_serializers.py b/tests/integration/schema_registry/_async/test_proto_serializers.py index 5f0b9f413..6aa908aa5 100644 --- a/tests/integration/schema_registry/_async/test_proto_serializers.py +++ b/tests/integration/schema_registry/_async/test_proto_serializers.py @@ -92,7 +92,7 @@ async def test_protobuf_reference_registration(kafka_cluster, pb2, expected_refs await producer.produce(topic, key=pb2(), partition=0) producer.flush() - registered_refs = (await sr.get_schema(serializer._schema_id)).references + registered_refs = (await sr.get_schema(serializer._schema_id.id)).references assert expected_refs.sort() == [ref.name for ref in registered_refs].sort() diff --git a/tests/integration/schema_registry/_sync/test_proto_serializers.py b/tests/integration/schema_registry/_sync/test_proto_serializers.py index 54e458152..39d42e2a3 100644 --- a/tests/integration/schema_registry/_sync/test_proto_serializers.py +++ b/tests/integration/schema_registry/_sync/test_proto_serializers.py @@ -92,7 +92,7 @@ def test_protobuf_reference_registration(kafka_cluster, pb2, expected_refs): producer.produce(topic, key=pb2(), partition=0) producer.flush() - registered_refs = (sr.get_schema(serializer._schema_id)).references + registered_refs = (sr.get_schema(serializer._schema_id.id)).references assert expected_refs.sort() == [ref.name for ref in registered_refs].sort() From 2fb4e3eaa28717bec2f0a2eb36eafb94dc411db6 
Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Wed, 28 May 2025 16:26:23 -0700 Subject: [PATCH 23/32] fix flake8 --- .../schema_registry/_async/avro.py | 7 +- .../schema_registry/_async/json_schema.py | 2 + .../schema_registry/_async/protobuf.py | 5 +- .../_async/schema_registry_client.py | 78 ++++++++++--------- .../schema_registry/_async/serde.py | 5 +- .../schema_registry/_sync/avro.py | 7 +- .../schema_registry/_sync/json_schema.py | 2 + .../schema_registry/_sync/protobuf.py | 5 +- .../_sync/schema_registry_client.py | 78 ++++++++++--------- .../schema_registry/_sync/serde.py | 5 +- tests/common/_async/consumer.py | 11 +-- tests/common/_async/producer.py | 40 +++++----- .../schema_registry/_async/test_api_client.py | 6 +- .../_async/test_avro_serializers.py | 77 +++++++++++------- .../_async/test_json_serializers.py | 9 ++- .../_async/test_proto_serializers.py | 9 ++- .../schema_registry/_sync/test_api_client.py | 6 +- .../_sync/test_avro_serializers.py | 77 +++++++++++------- .../_sync/test_json_serializers.py | 9 ++- .../_sync/test_proto_serializers.py | 9 ++- tools/unasync.py | 21 ++--- tox.ini | 4 + 22 files changed, 284 insertions(+), 188 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_async/avro.py b/src/confluent_kafka/schema_registry/_async/avro.py index b92b9a600..c7a523fbe 100644 --- a/src/confluent_kafka/schema_registry/_async/avro.py +++ b/src/confluent_kafka/schema_registry/_async/avro.py @@ -32,7 +32,8 @@ from confluent_kafka.serialization import (SerializationError, SerializationContext) from confluent_kafka.schema_registry.rule_registry import RuleRegistry -from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, ParsedSchemaCache, SchemaId +from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, \ + ParsedSchemaCache, SchemaId __all__ = [ @@ -62,6 +63,7 @@ async def _resolve_named_schema( named_schemas[ref.name] = parsed_schema return named_schemas + @asyncinit class AsyncAvroSerializer(AsyncBaseSerializer): """ @@ -506,7 +508,8 @@ async def __init__( def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: return self.__deserialize(data, ctx) - async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + async def __deserialize( + self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: """ Deserialize Avro binary encoded data with Confluent Schema Registry framing to a dict, or object instance according to from_dict, if specified. 
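With the Avro serde changes in place, the intended end-to-end usage looks roughly like the sketch below. This is a sketch only: the classes are imported from the _async modules added in this series (they may also be re-exported from the public avro module), the constructor shape mirrors the existing sync AvroSerializer, and the registry URL, topic, and record schema are illustrative placeholders.

    import asyncio

    from confluent_kafka.schema_registry import AsyncSchemaRegistryClient
    from confluent_kafka.schema_registry._async.avro import AsyncAvroSerializer, AsyncAvroDeserializer
    from confluent_kafka.serialization import SerializationContext, MessageField

    SCHEMA_STR = """
    {"type": "record", "name": "User",
     "fields": [{"name": "name", "type": "string"}]}
    """


    async def round_trip():
        client = AsyncSchemaRegistryClient({'url': 'http://localhost:8081'})  # illustrative URL

        # @asyncinit makes __init__ a coroutine, so construction itself is awaited.
        serializer = await AsyncAvroSerializer(client, SCHEMA_STR)
        deserializer = await AsyncAvroDeserializer(client)

        ctx = SerializationContext('example-topic', MessageField.VALUE)
        payload = await serializer({'name': 'alice'}, ctx)

        # The configured schema.id.serializer (prefix_schema_id_serializer by default)
        # framed the payload; dual_schema_id_deserializer recovers the SchemaId on read.
        print(await deserializer(payload, ctx))


    asyncio.run(round_trip())
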
diff --git a/src/confluent_kafka/schema_registry/_async/json_schema.py b/src/confluent_kafka/schema_registry/_async/json_schema.py index 89968ae42..c57c6b9a0 100644 --- a/src/confluent_kafka/schema_registry/_async/json_schema.py +++ b/src/confluent_kafka/schema_registry/_async/json_schema.py @@ -70,6 +70,7 @@ async def _resolve_named_schema( ref_registry = ref_registry.with_resource(ref.name, resource) return ref_registry + @asyncinit class AsyncJSONSerializer(AsyncBaseSerializer): """ @@ -400,6 +401,7 @@ def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Re self._validators[schema] = validator return validator + @asyncinit class AsyncJSONDeserializer(AsyncBaseDeserializer): """ diff --git a/src/confluent_kafka/schema_registry/_async/protobuf.py b/src/confluent_kafka/schema_registry/_async/protobuf.py index f6950eebf..ff060883c 100644 --- a/src/confluent_kafka/schema_registry/_async/protobuf.py +++ b/src/confluent_kafka/schema_registry/_async/protobuf.py @@ -38,7 +38,8 @@ from confluent_kafka.serialization import SerializationError, \ SerializationContext from confluent_kafka.schema_registry.common import asyncinit -from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, ParsedSchemaCache, SchemaId +from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, \ + ParsedSchemaCache, SchemaId __all__ = [ '_resolve_named_schema', @@ -72,6 +73,7 @@ async def _resolve_named_schema( file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) pool.Add(file_descriptor_proto) + @asyncinit class AsyncProtobufSerializer(AsyncBaseSerializer): """ @@ -457,6 +459,7 @@ async def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileD self._parsed_schemas.set(schema, (fd_proto, pool)) return fd_proto, pool + @asyncinit class AsyncProtobufDeserializer(AsyncBaseDeserializer): """ diff --git a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py index e5f929beb..f7e22d039 100644 --- a/src/confluent_kafka/schema_registry/_async/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_async/schema_registry_client.py @@ -283,10 +283,11 @@ def __init__(self, conf: dict): raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " + str(type(self.token_endpoint))) - self.bearer_field_provider = _AsyncOAuthClient(self.client_id, self.client_secret, self.scope, - self.token_endpoint, logical_cluster, identity_pool, - self.max_retries, self.retries_wait_ms, - self.retries_max_wait_ms) + self.bearer_field_provider = _AsyncOAuthClient( + self.client_id, self.client_secret, self.scope, + self.token_endpoint, logical_cluster, identity_pool, + self.max_retries, self.retries_wait_ms, + self.retries_max_wait_ms) elif self.bearer_auth_credentials_source == 'STATIC_TOKEN': if 'bearer.auth.token' not in conf_copy: raise ValueError("Missing bearer.auth.token") @@ -761,9 +762,11 @@ async def lookup_schema( request = schema.to_dict() - response = await self._rest_client.post('subjects/{}?normalize={}&deleted={}' - .format(_urlencode(subject_name), normalize_schemas, deleted), - body=request) + response = await self._rest_client.post( + 'subjects/{}?normalize={}&deleted={}'.format( + _urlencode(subject_name), normalize_schemas, deleted), + body=request + ) result = RegisteredSchema.from_dict(response) @@ -817,12 +820,14 @@ async def delete_subject(self, subject_name: str, permanent: bool 
= False) -> Li """ # noqa: E501 if permanent: - versions = await self._rest_client.delete('subjects/{}?permanent=true' - .format(_urlencode(subject_name))) + versions = await self._rest_client.delete( + 'subjects/{}?permanent=true'.format(_urlencode(subject_name)) + ) self._cache.remove_by_subject(subject_name) else: - versions = await self._rest_client.delete('subjects/{}' - .format(_urlencode(subject_name))) + versions = await self._rest_client.delete( + 'subjects/{}'.format(_urlencode(subject_name)) + ) return versions @@ -851,9 +856,9 @@ async def get_latest_version( return registered_schema query = {'format': fmt} if fmt is not None else None - response = await self._rest_client.get('subjects/{}/versions/{}' - .format(_urlencode(subject_name), - 'latest'), query) + response = await self._rest_client.get( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), 'latest'), query + ) registered_schema = RegisteredSchema.from_dict(response) @@ -892,8 +897,9 @@ async def get_latest_with_metadata( query['key'] = [_urlencode(key) for key in keys] query['value'] = [_urlencode(metadata[key]) for key in keys] - response = await self._rest_client.get('subjects/{}/metadata' - .format(_urlencode(subject_name)), query) + response = await self._rest_client.get( + 'subjects/{}/metadata'.format(_urlencode(subject_name)), query + ) registered_schema = RegisteredSchema.from_dict(response) @@ -929,9 +935,9 @@ async def get_version( return registered_schema query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} - response = await self._rest_client.get('subjects/{}/versions/{}' - .format(_urlencode(subject_name), - version), query) + response = await self._rest_client.get( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version), query + ) registered_schema = RegisteredSchema.from_dict(response) @@ -980,14 +986,14 @@ async def delete_version(self, subject_name: str, version: int, permanent: bool """ # noqa: E501 if permanent: - response = await self._rest_client.delete('subjects/{}/versions/{}?permanent=true' - .format(_urlencode(subject_name), - version)) + response = await self._rest_client.delete( + 'subjects/{}/versions/{}?permanent=true'.format(_urlencode(subject_name), version) + ) self._cache.remove_by_subject_version(subject_name, version) else: - response = await self._rest_client.delete('subjects/{}/versions/{}' - .format(_urlencode(subject_name), - version)) + response = await self._rest_client.delete( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version) + ) return response @@ -1016,12 +1022,13 @@ async def set_compatibility(self, subject_name: Optional[str] = None, level: Opt raise ValueError("level must be set") if subject_name is None: - return await self._rest_client.put('config', - body={'compatibility': level.upper()}) + return await self._rest_client.put( + 'config', body={'compatibility': level.upper()} + ) - return await self._rest_client.put('config/{}' - .format(_urlencode(subject_name)), - body={'compatibility': level.upper()}) + return await self._rest_client.put( + 'config/{}'.format(_urlencode(subject_name)), body={'compatibility': level.upper()} + ) async def get_compatibility(self, subject_name: Optional[str] = None) -> str: """ @@ -1106,12 +1113,13 @@ async def set_config( raise ValueError("config must be set") if subject_name is None: - return await self._rest_client.put('config', - body=config.to_dict()) + return await self._rest_client.put( + 'config', body=config.to_dict() + ) - return await 
self._rest_client.put('config/{}' - .format(_urlencode(subject_name)), - body=config.to_dict()) + return await self._rest_client.put( + 'config/{}'.format(_urlencode(subject_name)), body=config.to_dict() + ) async def get_config(self, subject_name: Optional[str] = None) -> 'ServerConfig': """ diff --git a/src/confluent_kafka/schema_registry/_async/serde.py b/src/confluent_kafka/schema_registry/_async/serde.py index c7e8fe461..40afca25e 100644 --- a/src/confluent_kafka/schema_registry/_async/serde.py +++ b/src/confluent_kafka/schema_registry/_async/serde.py @@ -187,8 +187,9 @@ class AsyncBaseSerializer(AsyncBaseSerde, Serializer): class AsyncBaseDeserializer(AsyncBaseSerde, Deserializer): __slots__ = ['_schema_id_deserializer'] - async def _get_writer_schema(self, schema_id: SchemaId, subject: Optional[str] = None, - fmt: Optional[str] = None) -> Schema: + async def _get_writer_schema( + self, schema_id: SchemaId, subject: Optional[str] = None, + fmt: Optional[str] = None) -> Schema: if schema_id.id is not None: return await self._registry.get_schema(schema_id.id, subject, fmt) elif schema_id.guid is not None: diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index 147dd2a34..b069b8c63 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -32,7 +32,8 @@ from confluent_kafka.serialization import (SerializationError, SerializationContext) from confluent_kafka.schema_registry.rule_registry import RuleRegistry -from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, ParsedSchemaCache, SchemaId +from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, \ + ParsedSchemaCache, SchemaId __all__ = [ @@ -63,6 +64,7 @@ def _resolve_named_schema( return named_schemas + class AvroSerializer(BaseSerializer): """ Serializer that outputs Avro binary encoded data with Confluent Schema Registry framing. @@ -506,7 +508,8 @@ def __init__( def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: return self.__deserialize(data, ctx) - def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: + def __deserialize( + self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: """ Deserialize Avro binary encoded data with Confluent Schema Registry framing to a dict, or object instance according to from_dict, if specified. diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index 61c46cd96..b8520b814 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -71,6 +71,7 @@ def _resolve_named_schema( return ref_registry + class JSONSerializer(BaseSerializer): """ Serializer that outputs JSON encoded data with Confluent Schema Registry framing. 
@@ -401,6 +402,7 @@ def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Re return validator + class JSONDeserializer(BaseDeserializer): """ Deserializer for JSON encoded data with Confluent Schema Registry diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index 7d4070465..b1a7cf0b0 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -38,7 +38,8 @@ from confluent_kafka.serialization import SerializationError, \ SerializationContext -from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, ParsedSchemaCache, SchemaId +from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, \ + ParsedSchemaCache, SchemaId __all__ = [ '_resolve_named_schema', @@ -73,6 +74,7 @@ def _resolve_named_schema( pool.Add(file_descriptor_proto) + class ProtobufSerializer(BaseSerializer): """ Serializer for Protobuf Message derived classes. Serialization format is Protobuf, @@ -458,6 +460,7 @@ def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescrip return fd_proto, pool + class ProtobufDeserializer(BaseDeserializer): """ Deserializer for Protobuf serialized data with Confluent Schema Registry framing. diff --git a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py index 65d700f83..f2caffc97 100644 --- a/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/_sync/schema_registry_client.py @@ -283,10 +283,11 @@ def __init__(self, conf: dict): raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " + str(type(self.token_endpoint))) - self.bearer_field_provider = _OAuthClient(self.client_id, self.client_secret, self.scope, - self.token_endpoint, logical_cluster, identity_pool, - self.max_retries, self.retries_wait_ms, - self.retries_max_wait_ms) + self.bearer_field_provider = _OAuthClient( + self.client_id, self.client_secret, self.scope, + self.token_endpoint, logical_cluster, identity_pool, + self.max_retries, self.retries_wait_ms, + self.retries_max_wait_ms) elif self.bearer_auth_credentials_source == 'STATIC_TOKEN': if 'bearer.auth.token' not in conf_copy: raise ValueError("Missing bearer.auth.token") @@ -761,9 +762,11 @@ def lookup_schema( request = schema.to_dict() - response = self._rest_client.post('subjects/{}?normalize={}&deleted={}' - .format(_urlencode(subject_name), normalize_schemas, deleted), - body=request) + response = self._rest_client.post( + 'subjects/{}?normalize={}&deleted={}'.format( + _urlencode(subject_name), normalize_schemas, deleted), + body=request + ) result = RegisteredSchema.from_dict(response) @@ -817,12 +820,14 @@ def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int """ # noqa: E501 if permanent: - versions = self._rest_client.delete('subjects/{}?permanent=true' - .format(_urlencode(subject_name))) + versions = self._rest_client.delete( + 'subjects/{}?permanent=true'.format(_urlencode(subject_name)) + ) self._cache.remove_by_subject(subject_name) else: - versions = self._rest_client.delete('subjects/{}' - .format(_urlencode(subject_name))) + versions = self._rest_client.delete( + 'subjects/{}'.format(_urlencode(subject_name)) + ) return versions @@ -851,9 +856,9 @@ def get_latest_version( return registered_schema query = {'format': fmt} if fmt is not None 
else None - response = self._rest_client.get('subjects/{}/versions/{}' - .format(_urlencode(subject_name), - 'latest'), query) + response = self._rest_client.get( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), 'latest'), query + ) registered_schema = RegisteredSchema.from_dict(response) @@ -892,8 +897,9 @@ def get_latest_with_metadata( query['key'] = [_urlencode(key) for key in keys] query['value'] = [_urlencode(metadata[key]) for key in keys] - response = self._rest_client.get('subjects/{}/metadata' - .format(_urlencode(subject_name)), query) + response = self._rest_client.get( + 'subjects/{}/metadata'.format(_urlencode(subject_name)), query + ) registered_schema = RegisteredSchema.from_dict(response) @@ -929,9 +935,9 @@ def get_version( return registered_schema query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} - response = self._rest_client.get('subjects/{}/versions/{}' - .format(_urlencode(subject_name), - version), query) + response = self._rest_client.get( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version), query + ) registered_schema = RegisteredSchema.from_dict(response) @@ -980,14 +986,14 @@ def delete_version(self, subject_name: str, version: int, permanent: bool = Fals """ # noqa: E501 if permanent: - response = self._rest_client.delete('subjects/{}/versions/{}?permanent=true' - .format(_urlencode(subject_name), - version)) + response = self._rest_client.delete( + 'subjects/{}/versions/{}?permanent=true'.format(_urlencode(subject_name), version) + ) self._cache.remove_by_subject_version(subject_name, version) else: - response = self._rest_client.delete('subjects/{}/versions/{}' - .format(_urlencode(subject_name), - version)) + response = self._rest_client.delete( + 'subjects/{}/versions/{}'.format(_urlencode(subject_name), version) + ) return response @@ -1016,12 +1022,13 @@ def set_compatibility(self, subject_name: Optional[str] = None, level: Optional[ raise ValueError("level must be set") if subject_name is None: - return self._rest_client.put('config', - body={'compatibility': level.upper()}) + return self._rest_client.put( + 'config', body={'compatibility': level.upper()} + ) - return self._rest_client.put('config/{}' - .format(_urlencode(subject_name)), - body={'compatibility': level.upper()}) + return self._rest_client.put( + 'config/{}'.format(_urlencode(subject_name)), body={'compatibility': level.upper()} + ) def get_compatibility(self, subject_name: Optional[str] = None) -> str: """ @@ -1106,12 +1113,13 @@ def set_config( raise ValueError("config must be set") if subject_name is None: - return self._rest_client.put('config', - body=config.to_dict()) + return self._rest_client.put( + 'config', body=config.to_dict() + ) - return self._rest_client.put('config/{}' - .format(_urlencode(subject_name)), - body=config.to_dict()) + return self._rest_client.put( + 'config/{}'.format(_urlencode(subject_name)), body=config.to_dict() + ) def get_config(self, subject_name: Optional[str] = None) -> 'ServerConfig': """ diff --git a/src/confluent_kafka/schema_registry/_sync/serde.py b/src/confluent_kafka/schema_registry/_sync/serde.py index 84ce15726..f0512f872 100644 --- a/src/confluent_kafka/schema_registry/_sync/serde.py +++ b/src/confluent_kafka/schema_registry/_sync/serde.py @@ -187,8 +187,9 @@ class BaseSerializer(BaseSerde, Serializer): class BaseDeserializer(BaseSerde, Deserializer): __slots__ = ['_schema_id_deserializer'] - def _get_writer_schema(self, schema_id: SchemaId, subject: Optional[str] = None, - 
fmt: Optional[str] = None) -> Schema: + def _get_writer_schema( + self, schema_id: SchemaId, subject: Optional[str] = None, + fmt: Optional[str] = None) -> Schema: if schema_id.id is not None: return self._registry.get_schema(schema_id.id, subject, fmt) elif schema_id.guid is not None: diff --git a/tests/common/_async/consumer.py b/tests/common/_async/consumer.py index bb74c7b51..d4ece6723 100644 --- a/tests/common/_async/consumer.py +++ b/tests/common/_async/consumer.py @@ -25,13 +25,14 @@ ASYNC_CONSUMER_POLL_INTERVAL_SECONDS: int = 0.2 ASYNC_CONSUMER_POLL_INFINITE_TIMEOUT_SECONDS: int = -1 + class AsyncConsumer(Consumer): def __init__( - self, - conf: dict, - loop: asyncio.AbstractEventLoop = None, - poll_interval_seconds: int = ASYNC_CONSUMER_POLL_INTERVAL_SECONDS - ): + self, + conf: dict, + loop: asyncio.AbstractEventLoop = None, + poll_interval_seconds: int = ASYNC_CONSUMER_POLL_INTERVAL_SECONDS + ): super().__init__(conf) self._loop = loop or asyncio.get_event_loop() diff --git a/tests/common/_async/producer.py b/tests/common/_async/producer.py index e9f811700..882685df2 100644 --- a/tests/common/_async/producer.py +++ b/tests/common/_async/producer.py @@ -25,13 +25,14 @@ ASYNC_PRODUCER_POLL_INTERVAL: int = 0.2 + class AsyncProducer(Producer): def __init__( - self, - conf: dict, - loop: asyncio.AbstractEventLoop = None, - poll_interval: int = ASYNC_PRODUCER_POLL_INTERVAL - ): + self, + conf: dict, + loop: asyncio.AbstractEventLoop = None, + poll_interval: int = ASYNC_PRODUCER_POLL_INTERVAL + ): super().__init__(conf) self._loop = loop or asyncio.get_event_loop() @@ -66,12 +67,12 @@ def wrapped_on_delivery(err, msg): self._loop.call_soon_threadsafe(fut.set_result, msg) super().produce( - topic, - value, - key, - headers=headers, - partition=partition, - timestamp=timestamp, + topic, + value, + key, + headers=headers, + partition=partition, + timestamp=timestamp, on_delivery=wrapped_on_delivery ) return await fut @@ -94,8 +95,9 @@ def __init__(self, conf): super(TestAsyncSerializingProducer, self).__init__(conf_copy) - async def produce(self, topic, key=None, value=None, partition=-1, - on_delivery=None, timestamp=0, headers=None): + async def produce( + self, topic, key=None, value=None, partition=-1, + on_delivery=None, timestamp=0, headers=None): ctx = SerializationContext(topic, MessageField.KEY, headers) if self._key_serializer is not None: try: @@ -109,8 +111,10 @@ async def produce(self, topic, key=None, value=None, partition=-1, except Exception as se: raise ValueSerializationError(se) - return await super().produce(topic, value, key, - headers=headers, - partition=partition, - timestamp=timestamp, - on_delivery=on_delivery) + return await super().produce( + topic, value, key, + headers=headers, + partition=partition, + timestamp=timestamp, + on_delivery=on_delivery + ) diff --git a/tests/integration/schema_registry/_async/test_api_client.py b/tests/integration/schema_registry/_async/test_api_client.py index 244e1c4b1..7ba3ee6c9 100644 --- a/tests/integration/schema_registry/_async/test_api_client.py +++ b/tests/integration/schema_registry/_async/test_api_client.py @@ -438,8 +438,10 @@ async def test_api_subject_config_update(kafka_cluster, load_file): subject = str(uuid1()) await sr.register_schema(subject, schema) - await sr.set_compatibility(subject_name=subject, - level="FULL_TRANSITIVE") + await sr.set_compatibility( + subject_name=subject, + level="FULL_TRANSITIVE" + ) assert await sr.get_compatibility(subject_name=subject) == "FULL_TRANSITIVE" diff --git 
a/tests/integration/schema_registry/_async/test_avro_serializers.py b/tests/integration/schema_registry/_async/test_avro_serializers.py index 966bbc14d..0cdf97011 100644 --- a/tests/integration/schema_registry/_async/test_avro_serializers.py +++ b/tests/integration/schema_registry/_async/test_avro_serializers.py @@ -154,27 +154,41 @@ async def _references_test_common(kafka_cluster, awarded_user, serializer_schema topic = kafka_cluster.create_topic_and_wait_propogation("reference-avro") sr = kafka_cluster.async_schema_registry() - value_serializer = await AsyncAvroSerializer(sr, serializer_schema, - lambda user, ctx: - dict(award=dict(name=user.award.name, - properties=dict(year=user.award.properties.year, - points=user.award.properties.points)), - user=dict(name=user.user.name, - favorite_number=user.user.favorite_number, - favorite_color=user.user.favorite_color))) + value_serializer = await AsyncAvroSerializer( + sr, + serializer_schema, + lambda user, ctx: dict( + award=dict( + name=user.award.name, + properties=dict(year=user.award.properties.year, points=user.award.properties.points) + ), + user=dict( + name=user.user.name, + favorite_number=user.user.favorite_number, + favorite_color=user.user.favorite_color + ) + ) + ) value_deserializer = \ - await AsyncAvroDeserializer(sr, deserializer_schema, - lambda user, ctx: - AwardedUser(award=Award(name=user.get('award').get('name'), - properties=AwardProperties( - year=user.get('award').get('properties').get( - 'year'), - points=user.get('award').get('properties').get( - 'points'))), - user=User(name=user.get('user').get('name'), - favorite_number=user.get('user').get('favorite_number'), - favorite_color=user.get('user').get('favorite_color')))) + await AsyncAvroDeserializer( + sr, + deserializer_schema, + lambda user, ctx: AwardedUser( + award=Award( + name=user.get('award').get('name'), + properties=AwardProperties( + year=user.get('award').get('properties').get('year'), + points=user.get('award').get('properties').get('points') + ) + ), + user=User( + name=user.get('user').get('name'), + favorite_number=user.get('user').get('favorite_number'), + favorite_color=user.get('user').get('favorite_color') + ) + ) + ) producer = kafka_cluster.async_producer(value_serializer=value_serializer) @@ -304,15 +318,22 @@ async def test_avro_record_serialization_custom(kafka_cluster): sr = kafka_cluster.async_schema_registry() user = User('Bowie', 47, 'purple') - value_serializer = await AsyncAvroSerializer(sr, User.schema_str, - lambda user, ctx: - dict(name=user.name, - favorite_number=user.favorite_number, - favorite_color=user.favorite_color)) - - value_deserializer = await AsyncAvroDeserializer(sr, User.schema_str, - lambda user_dict, ctx: - User(**user_dict)) + value_serializer = await AsyncAvroSerializer( + sr, + User.schema_str, + lambda user, ctx: + dict( + name=user.name, + favorite_number=user.favorite_number, + favorite_color=user.favorite_color + ) + ) + + value_deserializer = await AsyncAvroDeserializer( + sr, + User.schema_str, + lambda user_dict, ctx: User(**user_dict) + ) producer = kafka_cluster.async_producer(value_serializer=value_serializer) diff --git a/tests/integration/schema_registry/_async/test_json_serializers.py b/tests/integration/schema_registry/_async/test_json_serializers.py index 8f45e0aaf..464b41836 100644 --- a/tests/integration/schema_registry/_async/test_json_serializers.py +++ b/tests/integration/schema_registry/_async/test_json_serializers.py @@ -336,10 +336,11 @@ async def 
test_json_record_serialization_custom(kafka_cluster, load_file): sr = kafka_cluster.async_schema_registry() schema_str = load_file("product.json") - value_serializer = await AsyncJSONSerializer(schema_str, sr, - to_dict=_testProduct_to_dict) - value_deserializer = await AsyncJSONDeserializer(schema_str, - from_dict=_testProduct_from_dict) + value_serializer = await AsyncJSONSerializer(schema_str, sr, to_dict=_testProduct_to_dict) + value_deserializer = await AsyncJSONDeserializer( + schema_str, + from_dict=_testProduct_from_dict + ) producer = kafka_cluster.async_producer(value_serializer=value_serializer) diff --git a/tests/integration/schema_registry/_async/test_proto_serializers.py b/tests/integration/schema_registry/_async/test_proto_serializers.py index 6aa908aa5..0e65686e2 100644 --- a/tests/integration/schema_registry/_async/test_proto_serializers.py +++ b/tests/integration/schema_registry/_async/test_proto_serializers.py @@ -138,10 +138,11 @@ async def test_protobuf_deserializer_type_mismatch(kafka_cluster): def dr(err, msg): print("dr msg {} {}".format(msg.key(), msg.value())) - await producer.produce(topic, key=pb2_1(test_string='abc', - test_bool=True, - test_bytes=b'def'), - partition=0) + await producer.produce( + topic, + key=pb2_1(test_string='abc', test_bool=True, test_bytes=b'def'), + partition=0 + ) producer.flush() with pytest.raises(ConsumeError) as e: diff --git a/tests/integration/schema_registry/_sync/test_api_client.py b/tests/integration/schema_registry/_sync/test_api_client.py index 6b75e3b52..d841eb174 100644 --- a/tests/integration/schema_registry/_sync/test_api_client.py +++ b/tests/integration/schema_registry/_sync/test_api_client.py @@ -438,8 +438,10 @@ def test_api_subject_config_update(kafka_cluster, load_file): subject = str(uuid1()) sr.register_schema(subject, schema) - sr.set_compatibility(subject_name=subject, - level="FULL_TRANSITIVE") + sr.set_compatibility( + subject_name=subject, + level="FULL_TRANSITIVE" + ) assert sr.get_compatibility(subject_name=subject) == "FULL_TRANSITIVE" diff --git a/tests/integration/schema_registry/_sync/test_avro_serializers.py b/tests/integration/schema_registry/_sync/test_avro_serializers.py index 322637ae7..b40426080 100644 --- a/tests/integration/schema_registry/_sync/test_avro_serializers.py +++ b/tests/integration/schema_registry/_sync/test_avro_serializers.py @@ -154,27 +154,41 @@ def _references_test_common(kafka_cluster, awarded_user, serializer_schema, dese topic = kafka_cluster.create_topic_and_wait_propogation("reference-avro") sr = kafka_cluster.schema_registry() - value_serializer = AvroSerializer(sr, serializer_schema, - lambda user, ctx: - dict(award=dict(name=user.award.name, - properties=dict(year=user.award.properties.year, - points=user.award.properties.points)), - user=dict(name=user.user.name, - favorite_number=user.user.favorite_number, - favorite_color=user.user.favorite_color))) + value_serializer = AvroSerializer( + sr, + serializer_schema, + lambda user, ctx: dict( + award=dict( + name=user.award.name, + properties=dict(year=user.award.properties.year, points=user.award.properties.points) + ), + user=dict( + name=user.user.name, + favorite_number=user.user.favorite_number, + favorite_color=user.user.favorite_color + ) + ) + ) value_deserializer = \ - AvroDeserializer(sr, deserializer_schema, - lambda user, ctx: - AwardedUser(award=Award(name=user.get('award').get('name'), - properties=AwardProperties( - year=user.get('award').get('properties').get( - 'year'), - 
points=user.get('award').get('properties').get( - 'points'))), - user=User(name=user.get('user').get('name'), - favorite_number=user.get('user').get('favorite_number'), - favorite_color=user.get('user').get('favorite_color')))) + AvroDeserializer( + sr, + deserializer_schema, + lambda user, ctx: AwardedUser( + award=Award( + name=user.get('award').get('name'), + properties=AwardProperties( + year=user.get('award').get('properties').get('year'), + points=user.get('award').get('properties').get('points') + ) + ), + user=User( + name=user.get('user').get('name'), + favorite_number=user.get('user').get('favorite_number'), + favorite_color=user.get('user').get('favorite_color') + ) + ) + ) producer = kafka_cluster.producer(value_serializer=value_serializer) @@ -304,15 +318,22 @@ def test_avro_record_serialization_custom(kafka_cluster): sr = kafka_cluster.schema_registry() user = User('Bowie', 47, 'purple') - value_serializer = AvroSerializer(sr, User.schema_str, - lambda user, ctx: - dict(name=user.name, - favorite_number=user.favorite_number, - favorite_color=user.favorite_color)) - - value_deserializer = AvroDeserializer(sr, User.schema_str, - lambda user_dict, ctx: - User(**user_dict)) + value_serializer = AvroSerializer( + sr, + User.schema_str, + lambda user, ctx: + dict( + name=user.name, + favorite_number=user.favorite_number, + favorite_color=user.favorite_color + ) + ) + + value_deserializer = AvroDeserializer( + sr, + User.schema_str, + lambda user_dict, ctx: User(**user_dict) + ) producer = kafka_cluster.producer(value_serializer=value_serializer) diff --git a/tests/integration/schema_registry/_sync/test_json_serializers.py b/tests/integration/schema_registry/_sync/test_json_serializers.py index ae67c30f2..0c4a2545d 100644 --- a/tests/integration/schema_registry/_sync/test_json_serializers.py +++ b/tests/integration/schema_registry/_sync/test_json_serializers.py @@ -336,10 +336,11 @@ def test_json_record_serialization_custom(kafka_cluster, load_file): sr = kafka_cluster.schema_registry() schema_str = load_file("product.json") - value_serializer = JSONSerializer(schema_str, sr, - to_dict=_testProduct_to_dict) - value_deserializer = JSONDeserializer(schema_str, - from_dict=_testProduct_from_dict) + value_serializer = JSONSerializer(schema_str, sr, to_dict=_testProduct_to_dict) + value_deserializer = JSONDeserializer( + schema_str, + from_dict=_testProduct_from_dict + ) producer = kafka_cluster.producer(value_serializer=value_serializer) diff --git a/tests/integration/schema_registry/_sync/test_proto_serializers.py b/tests/integration/schema_registry/_sync/test_proto_serializers.py index 39d42e2a3..9b3ca3197 100644 --- a/tests/integration/schema_registry/_sync/test_proto_serializers.py +++ b/tests/integration/schema_registry/_sync/test_proto_serializers.py @@ -138,10 +138,11 @@ def test_protobuf_deserializer_type_mismatch(kafka_cluster): def dr(err, msg): print("dr msg {} {}".format(msg.key(), msg.value())) - producer.produce(topic, key=pb2_1(test_string='abc', - test_bool=True, - test_bytes=b'def'), - partition=0) + producer.produce( + topic, + key=pb2_1(test_string='abc', test_bool=True, test_bytes=b'def'), + partition=0 + ) producer.flush() with pytest.raises(ConsumeError) as e: diff --git a/tools/unasync.py b/tools/unasync.py index d922bf80d..61dc742d4 100644 --- a/tools/unasync.py +++ b/tools/unasync.py @@ -4,7 +4,6 @@ import re import sys import argparse -from pprint import pprint import subprocess SUBS = [ @@ -22,7 +21,7 @@ ('aclose', 'close'), ('__aenter__', '__enter__'), 
('__aexit__', '__exit__'), - ('__aiter__', '__iter__'), + ('__aiter__', '__iter__'), ] COMPILED_SUBS = [ @@ -32,6 +31,7 @@ USED_SUBS = set() + def unasync_line(line): for index, (regex, repl) in enumerate(COMPILED_SUBS): old_line = line @@ -76,17 +76,17 @@ def unasync_dir(in_dir, out_dir, check_only=False): else: unasync_file(in_path, out_path) + def check_diff(sync_dir): """Check if there are any differences in the sync directory. Returns a list of files that have differences.""" try: # Get the list of files in the sync directory - result = subprocess.run(['git', 'ls-files', sync_dir], - capture_output=True, text=True) + result = subprocess.run(['git', 'ls-files', sync_dir], capture_output=True, text=True) if result.returncode != 0: print(f"Error listing files in {sync_dir}") return [] - + files = result.stdout.strip().split('\n') if not files or (len(files) == 1 and not files[0]): print(f"No files found in {sync_dir}") @@ -97,8 +97,7 @@ def check_diff(sync_dir): for file in files: if not file: # Skip empty lines continue - diff_result = subprocess.run(['git', 'diff', '--quiet', file], - capture_output=True, text=True) + diff_result = subprocess.run(['git', 'diff', '--quiet', file], capture_output=True, text=True) if diff_result.returncode != 0: files_with_diff.append(file) return files_with_diff @@ -106,6 +105,7 @@ def check_diff(sync_dir): print(f"Error checking differences: {e}") return [] + def unasync(check=False): async_dirs = [ "src/confluent_kafka/schema_registry/_async", @@ -137,9 +137,12 @@ def unasync(check=False): else: print("\n✅ Conversion completed successfully!") + if __name__ == '__main__': parser = argparse.ArgumentParser(description='Convert async code to sync code') - parser.add_argument('--check', action='store_true', - help='Exit with non-zero status if sync directory has any differences') + parser.add_argument( + '--check', + action='store_true', + help='Exit with non-zero status if sync directory has any differences') args = parser.parse_args() unasync(check=args.check) diff --git a/tox.ini b/tox.ini index 55c786e54..c646479aa 100644 --- a/tox.ini +++ b/tox.ini @@ -29,3 +29,7 @@ norecursedirs = tests/integration/*/java exclude = venv*,.venv*,env,.env,.tox,.toxenv,.git,build,docs,tools,tmp-build,*_pb2.py,*tmp-KafkaCluster/* max-line-length = 119 accept-encodings = utf-8 +per-file-ignores = + ./src/confluent_kafka/schema_registry/_sync/avro.py: E303 + ./src/confluent_kafka/schema_registry/_sync/json_schema.py: E303 + ./src/confluent_kafka/schema_registry/_sync/protobuf.py: E303 From e9f35445ebf82d00f27d013b96037742b1eedd08 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Wed, 28 May 2025 16:31:23 -0700 Subject: [PATCH 24/32] ignore build --- tools/source-package-verification.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/source-package-verification.sh b/tools/source-package-verification.sh index 19ae2243b..9d3662337 100755 --- a/tools/source-package-verification.sh +++ b/tools/source-package-verification.sh @@ -36,7 +36,7 @@ if [[ $OS_NAME == linux && $ARCH == x64 ]]; then if [[ -z $TEST_CONSUMER_GROUP_PROTOCOL ]]; then # Run these actions and tests only in this case echo "Building documentation ..." 
- flake8 --exclude ./_venv,*_pb2.py + flake8 --exclude ./_venv,*_pb2.py,./build pip install -r requirements/requirements-docs.txt make docs From 09de81ce19fb6b6ef3033b4bb318a571c381843d Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Thu, 29 May 2025 10:00:32 -0700 Subject: [PATCH 25/32] fix tests --- requirements/requirements-tests.txt | 1 + tests/common/_async/__init__.py | 0 tests/common/_async/consumer.py | 6 +++++- tests/integration/integration_test.py | 3 ++- 4 files changed, 8 insertions(+), 2 deletions(-) create mode 100644 tests/common/_async/__init__.py diff --git a/requirements/requirements-tests.txt b/requirements/requirements-tests.txt index 932f9a427..6ed6d8f34 100644 --- a/requirements/requirements-tests.txt +++ b/requirements/requirements-tests.txt @@ -10,3 +10,4 @@ respx pytest_cov pluggy<1.6.0 pytest-asyncio +async-timeout diff --git a/tests/common/_async/__init__.py b/tests/common/_async/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/common/_async/consumer.py b/tests/common/_async/consumer.py index d4ece6723..6f5dea21b 100644 --- a/tests/common/_async/consumer.py +++ b/tests/common/_async/consumer.py @@ -15,8 +15,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import sys import asyncio +if sys.version_info >= (3, 11): + from asyncio import timeout +else: + from async_timeout import timeout from confluent_kafka.cimpl import Consumer from confluent_kafka.error import ConsumeError, KeyDeserializationError, ValueDeserializationError diff --git a/tests/integration/integration_test.py b/tests/integration/integration_test.py index d291c8751..258c6f5a8 100755 --- a/tests/integration/integration_test.py +++ b/tests/integration/integration_test.py @@ -28,6 +28,7 @@ import gc import struct import re +import pytest import confluent_kafka @@ -212,7 +213,7 @@ def verify_producer(): # Global variable to track garbage collection of suppressed on_delivery callbacks DrOnlyTestSuccess_gced = 0 - +@pytest.mark.skip(reason="This module must be run as a script") def test_producer_dr_only_error(): """ The C delivery.report.only.error configuration property From 7fe4e3fd072f8724d24f8480a42f7048d3e170d8 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Thu, 29 May 2025 10:01:23 -0700 Subject: [PATCH 26/32] fix --- tests/integration/integration_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/integration_test.py b/tests/integration/integration_test.py index 258c6f5a8..39e368890 100755 --- a/tests/integration/integration_test.py +++ b/tests/integration/integration_test.py @@ -213,7 +213,7 @@ def verify_producer(): # Global variable to track garbage collection of suppressed on_delivery callbacks DrOnlyTestSuccess_gced = 0 -@pytest.mark.skip(reason="This module must be run as a script") +@pytest.mark.skip(reason="This test must be run as a standalone script") def test_producer_dr_only_error(): """ The C delivery.report.only.error configuration property From dac2befe89e1f3e134aca2a78815edf5112b035c Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Thu, 29 May 2025 15:03:37 -0700 Subject: [PATCH 27/32] fix flake8 --- tests/common/_async/consumer.py | 14 +++++++------- tests/integration/integration_test.py | 1 + 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/common/_async/consumer.py b/tests/common/_async/consumer.py index 6f5dea21b..0fc5b0233 100644 --- a/tests/common/_async/consumer.py +++ b/tests/common/_async/consumer.py @@ -20,7 +20,7 
@@ if sys.version_info >= (3, 11): from asyncio import timeout else: - from async_timeout import timeout + from async_timeout import timeout # noqa: F401 from confluent_kafka.cimpl import Consumer from confluent_kafka.error import ConsumeError, KeyDeserializationError, ValueDeserializationError @@ -48,9 +48,9 @@ def __aiter__(self): async def __anext__(self): return await self.poll(None) - async def poll(self, timeout: int = -1): - timeout = None if timeout == -1 else timeout - async with asyncio.timeout(timeout): + async def poll(self, poll_timeout: int = -1): + poll_timeout = None if poll_timeout == -1 else poll_timeout + async with asyncio.timeout(poll_timeout): while True: # Zero timeout here is what makes it non-blocking msg = super().poll(0) @@ -67,8 +67,8 @@ def __init__(self, conf): self._value_deserializer = conf_copy.pop('value.deserializer', None) super().__init__(conf_copy) - async def poll(self, timeout=-1): - msg = await super().poll(timeout) + async def poll(self, poll_timeout=-1): + msg = await super().poll(poll_timeout) if msg is None: return None @@ -96,7 +96,7 @@ async def poll(self, timeout=-1): msg.set_value(value) return msg - def consume(self, num_messages=1, timeout=-1): + def consume(self, num_messages=1, consume_timeout=-1): """ :py:func:`Consumer.consume` not implemented, use :py:func:`DeserializingConsumer.poll` instead diff --git a/tests/integration/integration_test.py b/tests/integration/integration_test.py index 39e368890..c7e44e021 100755 --- a/tests/integration/integration_test.py +++ b/tests/integration/integration_test.py @@ -213,6 +213,7 @@ def verify_producer(): # Global variable to track garbage collection of suppressed on_delivery callbacks DrOnlyTestSuccess_gced = 0 + @pytest.mark.skip(reason="This test must be run as a standalone script") def test_producer_dr_only_error(): """ From a8d05cb909a539bdb65241153ae84a725f98809c Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Thu, 29 May 2025 20:53:44 -0700 Subject: [PATCH 28/32] add async sr import --- src/confluent_kafka/schema_registry/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/confluent_kafka/schema_registry/__init__.py b/src/confluent_kafka/schema_registry/__init__.py index 25f21db7f..c19e0f46f 100644 --- a/src/confluent_kafka/schema_registry/__init__.py +++ b/src/confluent_kafka/schema_registry/__init__.py @@ -31,6 +31,7 @@ RuleSet, Schema, SchemaRegistryClient, + AsyncSchemaRegistryClient, SchemaRegistryError, SchemaReference, ServerConfig From 02ce9309034c2534800fed305006e1bbc3e46694 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Thu, 29 May 2025 21:08:27 -0700 Subject: [PATCH 29/32] Auto-generate README --- DEVELOPER.md | 9 +++++ .../schema_registry/_sync/README.md | 7 ++++ .../schema_registry/_sync/README.md | 7 ++++ tools/unasync.py | 40 ++++++++++++++----- 4 files changed, 53 insertions(+), 10 deletions(-) create mode 100644 src/confluent_kafka/schema_registry/_sync/README.md create mode 100644 tests/integration/schema_registry/_sync/README.md diff --git a/DEVELOPER.md b/DEVELOPER.md index 8a50c7bf9..e886fa63a 100644 --- a/DEVELOPER.md +++ b/DEVELOPER.md @@ -31,6 +31,15 @@ or: Documentation will be generated in `build/sphinx/html`. 
+## Unasync -- maintaining sync versions of async code + + $ python tools/unasync.py + + # Run the script with the --check flag to ensure the sync code is up to date + $ python tools/unasync.py --check + +If you make any changes to the async code (in `src/confluent_kafka/schema_registry/_async` and `tests/integration/schema_registry/_async`), you **must** run this script to generate the sync counter parts (in `src/confluent_kafka/schema_registry/_sync` and `tests/integration/schema_registry/_sync`). Otherwise, this script will be run in CI with the --check flag and fail the build. + ## Tests diff --git a/src/confluent_kafka/schema_registry/_sync/README.md b/src/confluent_kafka/schema_registry/_sync/README.md new file mode 100644 index 000000000..905a46481 --- /dev/null +++ b/src/confluent_kafka/schema_registry/_sync/README.md @@ -0,0 +1,7 @@ +# Auto-generated Directory + +This directory contains auto-generated code. Do not edit these files directly. + +To make changes: +1. Edit the corresponding files in the sibling `_async` directory +2. Run `python tools/unasync.py` to propagate the changes to this `_sync` directory diff --git a/tests/integration/schema_registry/_sync/README.md b/tests/integration/schema_registry/_sync/README.md new file mode 100644 index 000000000..905a46481 --- /dev/null +++ b/tests/integration/schema_registry/_sync/README.md @@ -0,0 +1,7 @@ +# Auto-generated Directory + +This directory contains auto-generated code. Do not edit these files directly. + +To make changes: +1. Edit the corresponding files in the sibling `_async` directory +2. Run `python tools/unasync.py` to propagate the changes to this `_sync` directory diff --git a/tools/unasync.py b/tools/unasync.py index 61dc742d4..ff5c5fe19 100644 --- a/tools/unasync.py +++ b/tools/unasync.py @@ -6,6 +6,15 @@ import argparse import subprocess +# List of directories to convert from async to sync +# Each tuple contains the async directory and its sync counterpart +# If you add a new _async directory and want the _sync directory to be +# generated, you must add it to this list. +ASYNC_TO_SYNC = [ + ("src/confluent_kafka/schema_registry/_async", "src/confluent_kafka/schema_registry/_sync"), + ("tests/integration/schema_registry/_async", "tests/integration/schema_registry/_sync") +] + SUBS = [ ('from confluent_kafka.schema_registry.common import asyncinit', ''), ('@asyncinit', ''), @@ -63,6 +72,23 @@ def unasync_file_check(in_path, out_path): def unasync_dir(in_dir, out_dir, check_only=False): + # Create the output directory if it doesn't exist + os.makedirs(out_dir, exist_ok=True) + + # Create README.md in the sync directory + readme_path = os.path.join(out_dir, "README.md") + readme_content = """# Auto-generated Directory + +This directory contains auto-generated code. Do not edit these files directly. + +To make changes: +1. Edit the corresponding files in the sibling `_async` directory +2. 
Run `python tools/unasync.py` to propagate the changes to this `_sync` directory +""" + if not check_only: + with open(readme_path, "w") as f: + f.write(readme_content) + for dirpath, dirnames, filenames in os.walk(in_dir): for filename in filenames: if not filename.endswith('.py'): @@ -70,6 +96,8 @@ def unasync_dir(in_dir, out_dir, check_only=False): rel_dir = os.path.relpath(dirpath, in_dir) in_path = os.path.normpath(os.path.join(in_dir, rel_dir, filename)) out_path = os.path.normpath(os.path.join(out_dir, rel_dir, filename)) + # Create the subdirectory if it doesn't exist + os.makedirs(os.path.dirname(out_path), exist_ok=True) print(in_path, '->', out_path) if check_only: unasync_file_check(in_path, out_path) @@ -107,22 +135,14 @@ def check_diff(sync_dir): def unasync(check=False): - async_dirs = [ - "src/confluent_kafka/schema_registry/_async", - "tests/integration/schema_registry/_async" - ] - sync_dirs = [ - "src/confluent_kafka/schema_registry/_sync", - "tests/integration/schema_registry/_sync" - ] print("Converting async code to sync code...") - for async_dir, sync_dir in zip(async_dirs, sync_dirs): + for async_dir, sync_dir in ASYNC_TO_SYNC: unasync_dir(async_dir, sync_dir, check_only=False) files_with_diff = [] if check: - for sync_dir in sync_dirs: + for _, sync_dir in ASYNC_TO_SYNC: files_with_diff.extend(check_diff(sync_dir)) if files_with_diff: From 05d2918c51a84ee5435ad7a1d67744ca179dfd1a Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Thu, 29 May 2025 21:10:57 -0700 Subject: [PATCH 30/32] use timeout --- tests/common/_async/consumer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/common/_async/consumer.py b/tests/common/_async/consumer.py index 0fc5b0233..81f036ca0 100644 --- a/tests/common/_async/consumer.py +++ b/tests/common/_async/consumer.py @@ -50,7 +50,7 @@ async def __anext__(self): async def poll(self, poll_timeout: int = -1): poll_timeout = None if poll_timeout == -1 else poll_timeout - async with asyncio.timeout(poll_timeout): + async with timeout(poll_timeout): while True: # Zero timeout here is what makes it non-blocking msg = super().poll(0) From 187b9192ef0732e61347957db46dfeb9f758c04b Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Thu, 29 May 2025 21:52:16 -0700 Subject: [PATCH 31/32] Refactor unasync and add tests to prove functionality --- tests/test_unasync.py | 189 ++++++++++++++++++++++++++++++++++++++++++ tools/unasync.py | 146 ++++++++++++++++++-------------- 2 files changed, 273 insertions(+), 62 deletions(-) create mode 100644 tests/test_unasync.py diff --git a/tests/test_unasync.py b/tests/test_unasync.py new file mode 100644 index 000000000..a433175f6 --- /dev/null +++ b/tests/test_unasync.py @@ -0,0 +1,189 @@ +from tools.unasync import unasync, unasync_line, unasync_file_check + +import os +import tempfile + +import pytest + +@pytest.fixture +def temp_dirs(): + """Create temporary directories for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create async and sync directories + async_dir = os.path.join(temp_dir, "async") + sync_dir = os.path.join(temp_dir, "sync") + os.makedirs(async_dir) + os.makedirs(sync_dir) + yield async_dir, sync_dir + + +def test_unasync_line(): + """Test the unasync_line function with various inputs.""" + test_cases = [ + ("async def test():", "def test():"), + ("await some_func()", "some_func()"), + ("from confluent_kafka.schema_registry.common import asyncinit", ""), + ("@asyncinit", ""), + ("import asyncio", ""), + ("asyncio.sleep(1)", "time.sleep(1)"), + ("class 
AsyncTest:", "class Test:"), + ("class _AsyncTest:", "class _Test:"), + ("async_test_func", "test_func"), + ] + + for input_line, expected in test_cases: + assert unasync_line(input_line) == expected + + +def test_unasync_file_check(temp_dirs): + """Test the unasync_file_check function with various scenarios.""" + async_dir, sync_dir = temp_dirs + + # Test case 1: Files match + async_file = os.path.join(async_dir, "test1.py") + sync_file = os.path.join(sync_dir, "test1.py") + os.makedirs(os.path.dirname(sync_file), exist_ok=True) + + with open(async_file, "w") as f: + f.write("""async def test(): + await asyncio.sleep(1) +""") + + with open(sync_file, "w") as f: + f.write("""def test(): + time.sleep(1) +""") + + # This should return True + assert unasync_file_check(async_file, sync_file) is True + + # Test case 2: Files don't match + async_file = os.path.join(async_dir, "test2.py") + sync_file = os.path.join(sync_dir, "test2.py") + os.makedirs(os.path.dirname(sync_file), exist_ok=True) + + with open(async_file, "w") as f: + f.write("""async def test(): + await asyncio.sleep(1) +""") + + with open(sync_file, "w") as f: + f.write("""def test(): + # This is wrong + asyncio.sleep(1) +""") + + # This should return False + assert unasync_file_check(async_file, sync_file) is False + + # Test case 3: Files have different lengths + async_file = os.path.join(async_dir, "test3.py") + sync_file = os.path.join(sync_dir, "test3.py") + os.makedirs(os.path.dirname(sync_file), exist_ok=True) + + with open(async_file, "w") as f: + f.write("""async def test(): + await asyncio.sleep(1) + return "test" +""") + + with open(sync_file, "w") as f: + f.write("""def test(): + time.sleep(1) +""") + + # This should return False + assert unasync_file_check(async_file, sync_file) is False + + # Test case 4: File not found + with pytest.raises(ValueError, match="Error comparing"): + unasync_file_check("nonexistent.py", "also_nonexistent.py") + + +def test_unasync_generation(temp_dirs): + """Test the unasync generation functionality.""" + async_dir, sync_dir = temp_dirs + + # Create a test async file + test_file = os.path.join(async_dir, "test.py") + with open(test_file, "w") as f: + f.write("""async def test_func(): + await asyncio.sleep(1) + return "test" + +class AsyncTest: + async def test_method(self): + await self.some_async() +""") + + # Run unasync with test directories + dir_pairs = [(async_dir, sync_dir)] + unasync(dir_pairs=dir_pairs, check=False) + + # Check if sync file was created + sync_file = os.path.join(sync_dir, "test.py") + assert os.path.exists(sync_file) + + # Check content + with open(sync_file, "r") as f: + content = f.read() + assert "async def" not in content + assert "await" not in content + assert "AsyncTest" not in content + assert "Test" in content + + # Check if README was created + readme_file = os.path.join(sync_dir, "README.md") + assert os.path.exists(readme_file) + with open(readme_file, "r") as f: + readme_content = f.read() + assert "Auto-generated Directory" in readme_content + assert "Do not edit these files directly" in readme_content + + +def test_unasync_check(temp_dirs): + """Test the unasync check functionality.""" + async_dir, sync_dir = temp_dirs + + # Create a test async file + test_file = os.path.join(async_dir, "test.py") + with open(test_file, "w") as f: + f.write("""async def test_func(): + await asyncio.sleep(1) + return "test" +""") + + # Create an incorrect sync file + sync_file = os.path.join(sync_dir, "test.py") + os.makedirs(os.path.dirname(sync_file), 
exist_ok=True) + with open(sync_file, "w") as f: + f.write("""def test_func(): + time.sleep(1) + return "test" + # Extra line that shouldn't be here +""") + + # Run unasync check with test directories + dir_pairs = [(async_dir, sync_dir)] + with pytest.raises(SystemExit) as excinfo: + unasync(dir_pairs=dir_pairs, check=True) + assert excinfo.value.code == 1 + + +def test_unasync_missing_sync_file(temp_dirs): + """Test unasync check with missing sync files.""" + async_dir, sync_dir = temp_dirs + + # Create a test async file + test_file = os.path.join(async_dir, "test.py") + with open(test_file, "w") as f: + f.write("""async def test_func(): + await asyncio.sleep(1) + return "test" +""") + + # Run unasync check with test directories + dir_pairs = [(async_dir, sync_dir)] + with pytest.raises(SystemExit) as excinfo: + unasync(dir_pairs=dir_pairs, check=True) + assert excinfo.value.code == 1 diff --git a/tools/unasync.py b/tools/unasync.py index ff5c5fe19..173d1d48e 100644 --- a/tools/unasync.py +++ b/tools/unasync.py @@ -4,7 +4,7 @@ import re import sys import argparse -import subprocess +import difflib # List of directories to convert from async to sync # Each tuple contains the async directory and its sync counterpart @@ -59,22 +59,70 @@ def unasync_file(in_path, out_path): def unasync_file_check(in_path, out_path): - with open(in_path, "r") as in_file: + """Check if the sync file matches the expected generated content. + + Args: + in_path: Path to the async file + out_path: Path to the sync file + + Returns: + bool: True if files match, False if they don't + + Raises: + ValueError: If there's an error reading the files + """ + try: + with open(in_path, "r") as in_file: + async_content = in_file.read() + expected_content = "".join(unasync_line(line) for line in async_content.splitlines(keepends=True)) + with open(out_path, "r") as out_file: - for in_line, out_line in zip(in_file.readlines(), out_file.readlines()): - expected = unasync_line(in_line) - if out_line != expected: - print(f'unasync mismatch between {in_path!r} and {out_path!r}') - print(f'Async code: {in_line!r}') - print(f'Expected sync code: {expected!r}') - print(f'Actual sync code: {out_line!r}') - sys.exit(1) + actual_content = out_file.read() + + if actual_content != expected_content: + diff = difflib.unified_diff( + expected_content.splitlines(keepends=True), + actual_content.splitlines(keepends=True), + fromfile=in_path, + tofile=out_path, + n=3 # Show 3 lines of context + ) + print(''.join(diff)) + return False + return True + except Exception as e: + print(f"Error comparing {in_path} and {out_path}: {e}") + raise ValueError(f"Error comparing {in_path} and {out_path}: {e}") + + +def check_sync_files(dir_pairs): + """Check if all sync files match their expected generated content. 
+ Returns a list of files that have differences.""" + files_with_diff = [] + + for async_dir, sync_dir in dir_pairs: + for dirpath, _, filenames in os.walk(async_dir): + for filename in filenames: + if not filename.endswith('.py'): + continue + rel_dir = os.path.relpath(dirpath, async_dir) + async_path = os.path.normpath(os.path.join(async_dir, rel_dir, filename)) + sync_path = os.path.normpath(os.path.join(sync_dir, rel_dir, filename)) + + if not os.path.exists(sync_path): + files_with_diff.append(sync_path) + continue + if not unasync_file_check(async_path, sync_path): + files_with_diff.append(sync_path) -def unasync_dir(in_dir, out_dir, check_only=False): + return files_with_diff + + +def unasync_dir(in_dir, out_dir): # Create the output directory if it doesn't exist os.makedirs(out_dir, exist_ok=True) - + # Create README.md in the sync directory readme_path = os.path.join(out_dir, "README.md") readme_content = """# Auto-generated Directory @@ -85,11 +133,10 @@ def unasync_dir(in_dir, out_dir, check_only=False): 1. Edit the corresponding files in the sibling `_async` directory 2. Run `python tools/unasync.py` to propagate the changes to this `_sync` directory """ - if not check_only: - with open(readme_path, "w") as f: - f.write(readme_content) - - for dirpath, dirnames, filenames in os.walk(in_dir): + with open(readme_path, "w") as f: + f.write(readme_content) + + for dirpath, _, filenames in os.walk(in_dir): for filename in filenames: if not filename.endswith('.py'): continue @@ -99,63 +146,38 @@ def unasync_dir(in_dir, out_dir, check_only=False): # Create the subdirectory if it doesn't exist os.makedirs(os.path.dirname(out_path), exist_ok=True) print(in_path, '->', out_path) - if check_only: - unasync_file_check(in_path, out_path) - else: - unasync_file(in_path, out_path) - + unasync_file(in_path, out_path) -def check_diff(sync_dir): - """Check if there are any differences in the sync directory. - Returns a list of files that have differences.""" - try: - # Get the list of files in the sync directory - result = subprocess.run(['git', 'ls-files', sync_dir], capture_output=True, text=True) - if result.returncode != 0: - print(f"Error listing files in {sync_dir}") - return [] - - files = result.stdout.strip().split('\n') - if not files or (len(files) == 1 and not files[0]): - print(f"No files found in {sync_dir}") - return [] - - # Check if any of these files have differences - files_with_diff = [] - for file in files: - if not file: # Skip empty lines - continue - diff_result = subprocess.run(['git', 'diff', '--quiet', file], capture_output=True, text=True) - if diff_result.returncode != 0: - files_with_diff.append(file) - return files_with_diff - except subprocess.CalledProcessError as e: - print(f"Error checking differences: {e}") - return [] +def unasync(dir_pairs=None, check=False): + """Convert async code to sync code. -def unasync(check=False): - - print("Converting async code to sync code...") - for async_dir, sync_dir in ASYNC_TO_SYNC: - unasync_dir(async_dir, sync_dir, check_only=False) + Args: + dir_pairs: List of (async_dir, sync_dir) tuples to process. If None, uses ASYNC_TO_SYNC. + check: If True, only check if sync files are up to date without modifying them. 
+ """ + if dir_pairs is None: + dir_pairs = ASYNC_TO_SYNC files_with_diff = [] if check: - for _, sync_dir in ASYNC_TO_SYNC: - files_with_diff.extend(check_diff(sync_dir)) + files_with_diff = check_sync_files(dir_pairs) if files_with_diff: - print("\n⚠️ Detected changes to a _sync directory that are uncommitted.") - print("\nFiles with differences:") + print("\n⚠️ Detected differences between async and sync files.") + print("\nFiles that need to be regenerated:") for file in files_with_diff: print(f" - {file}") - print("\nPlease either:") - print("1. Commit the changes in the generated _sync files, or") - print("2. Revert the changes in the original _async files") + print("\nPlease run this script again (without the --check flag) to regenerate the sync files.") sys.exit(1) else: - print("\n✅ Conversion completed successfully!") + print("\n✅ All _sync directories are up to date!") + if not check: + print("Converting async code to sync code...") + for async_dir, sync_dir in dir_pairs: + unasync_dir(async_dir, sync_dir) + + print("\n✅ Generated sync code from async code.") if __name__ == '__main__': From 5e0d7be800ff0b3172e2a544a8135879090c4228 Mon Sep 17 00:00:00 2001 From: Rohit Sanjay Date: Thu, 29 May 2025 21:52:33 -0700 Subject: [PATCH 32/32] fix --- tests/test_unasync.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_unasync.py b/tests/test_unasync.py index a433175f6..ea475e04f 100644 --- a/tests/test_unasync.py +++ b/tests/test_unasync.py @@ -5,6 +5,7 @@ import pytest + @pytest.fixture def temp_dirs(): """Create temporary directories for testing."""