Skip to content

Commit b1fa5d9

Browse files
authored
Refactor for async support (#1971)
* Refactor for async support * remove avro module import * Add top level __all__ to refactored modules * reduce diffs * refactor * formatting * style fix * fix flake8 * fix * fix flake8 * revert * reduce diff
1 parent e6cec6b commit b1fa5d9

25 files changed

+5733
-5405
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import os
44
from setuptools import setup
5-
from distutils.core import Extension
5+
from setuptools import Extension
66
import platform
77

88
work_dir = os.path.dirname(os.path.realpath(__file__))

src/confluent_kafka/schema_registry/_sync/avro.py

Lines changed: 599 additions & 0 deletions
Large diffs are not rendered by default.

src/confluent_kafka/schema_registry/_sync/json_schema.py

Lines changed: 657 additions & 0 deletions
Large diffs are not rendered by default.

src/confluent_kafka/schema_registry/_sync/protobuf.py

Lines changed: 714 additions & 0 deletions
Large diffs are not rendered by default.

src/confluent_kafka/schema_registry/_sync/schema_registry_client.py

Lines changed: 1157 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright 2024 Confluent Inc.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
import logging
20+
from typing import List, Optional, Set, Dict, Any
21+
22+
from confluent_kafka.schema_registry import RegisteredSchema
23+
from confluent_kafka.schema_registry.common.serde import ErrorAction, \
24+
FieldTransformer, Migration, NoneAction, RuleAction, \
25+
RuleConditionError, RuleContext, RuleError, SchemaId
26+
from confluent_kafka.schema_registry.schema_registry_client import RuleMode, \
27+
Rule, RuleKind, Schema, RuleSet
28+
from confluent_kafka.serialization import Serializer, Deserializer, \
29+
SerializationContext, SerializationError
30+
31+
__all__ = [
32+
'BaseSerde',
33+
'BaseSerializer',
34+
'BaseDeserializer',
35+
]
36+
37+
log = logging.getLogger(__name__)
38+
39+
40+
class BaseSerde(object):
41+
__slots__ = ['_use_schema_id', '_use_latest_version', '_use_latest_with_metadata',
42+
'_registry', '_rule_registry', '_subject_name_func',
43+
'_field_transformer']
44+
45+
def _get_reader_schema(self, subject: str, fmt: Optional[str] = None) -> Optional[RegisteredSchema]:
46+
if self._use_schema_id is not None:
47+
schema = self._registry.get_schema(self._use_schema_id, subject, fmt)
48+
return self._registry.lookup_schema(subject, schema, False, True)
49+
if self._use_latest_with_metadata is not None:
50+
return self._registry.get_latest_with_metadata(
51+
subject, self._use_latest_with_metadata, True, fmt)
52+
if self._use_latest_version:
53+
return self._registry.get_latest_version(subject, fmt)
54+
return None
55+
56+
def _execute_rules(
57+
self, ser_ctx: SerializationContext, subject: str,
58+
rule_mode: RuleMode,
59+
source: Optional[Schema], target: Optional[Schema],
60+
message: Any, inline_tags: Optional[Dict[str, Set[str]]],
61+
field_transformer: Optional[FieldTransformer]
62+
) -> Any:
63+
if message is None or target is None:
64+
return message
65+
rules: Optional[List[Rule]] = None
66+
if rule_mode == RuleMode.UPGRADE:
67+
if target is not None and target.rule_set is not None:
68+
rules = target.rule_set.migration_rules
69+
elif rule_mode == RuleMode.DOWNGRADE:
70+
if source is not None and source.rule_set is not None:
71+
rules = source.rule_set.migration_rules
72+
rules = rules[:] if rules else []
73+
rules.reverse()
74+
else:
75+
if target is not None and target.rule_set is not None:
76+
rules = target.rule_set.domain_rules
77+
if rule_mode == RuleMode.READ:
78+
# Execute read rules in reverse order for symmetry
79+
rules = rules[:] if rules else []
80+
rules.reverse()
81+
82+
if not rules:
83+
return message
84+
85+
for index in range(len(rules)):
86+
rule = rules[index]
87+
if self._is_disabled(rule):
88+
continue
89+
if rule.mode == RuleMode.WRITEREAD:
90+
if rule_mode != RuleMode.READ and rule_mode != RuleMode.WRITE:
91+
continue
92+
elif rule.mode == RuleMode.UPDOWN:
93+
if rule_mode != RuleMode.UPGRADE and rule_mode != RuleMode.DOWNGRADE:
94+
continue
95+
elif rule.mode != rule_mode:
96+
continue
97+
98+
ctx = RuleContext(ser_ctx, source, target, subject, rule_mode, rule,
99+
index, rules, inline_tags, field_transformer)
100+
rule_executor = self._rule_registry.get_executor(rule.type.upper())
101+
if rule_executor is None:
102+
self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), message,
103+
RuleError(f"Could not find rule executor of type {rule.type}"),
104+
'ERROR')
105+
return message
106+
try:
107+
result = rule_executor.transform(ctx, message)
108+
if rule.kind == RuleKind.CONDITION:
109+
if not result:
110+
raise RuleConditionError(rule)
111+
elif rule.kind == RuleKind.TRANSFORM:
112+
message = result
113+
self._run_action(
114+
ctx, rule_mode, rule,
115+
self._get_on_failure(rule) if message is None else self._get_on_success(rule),
116+
message, None,
117+
'ERROR' if message is None else 'NONE')
118+
except SerializationError:
119+
raise
120+
except Exception as e:
121+
self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule),
122+
message, e, 'ERROR')
123+
return message
124+
125+
def _get_on_success(self, rule: Rule) -> Optional[str]:
126+
override = self._rule_registry.get_override(rule.type)
127+
if override is not None and override.on_success is not None:
128+
return override.on_success
129+
return rule.on_success
130+
131+
def _get_on_failure(self, rule: Rule) -> Optional[str]:
132+
override = self._rule_registry.get_override(rule.type)
133+
if override is not None and override.on_failure is not None:
134+
return override.on_failure
135+
return rule.on_failure
136+
137+
def _is_disabled(self, rule: Rule) -> Optional[bool]:
138+
override = self._rule_registry.get_override(rule.type)
139+
if override is not None and override.disabled is not None:
140+
return override.disabled
141+
return rule.disabled
142+
143+
def _run_action(
144+
self, ctx: RuleContext, rule_mode: RuleMode, rule: Rule,
145+
action: Optional[str], message: Any,
146+
ex: Optional[Exception], default_action: str
147+
):
148+
action_name = self._get_rule_action_name(rule, rule_mode, action)
149+
if action_name is None:
150+
action_name = default_action
151+
rule_action = self._get_rule_action(ctx, action_name)
152+
if rule_action is None:
153+
log.error("Could not find rule action of type %s", action_name)
154+
raise RuleError(f"Could not find rule action of type {action_name}")
155+
try:
156+
rule_action.run(ctx, message, ex)
157+
except SerializationError:
158+
raise
159+
except Exception as e:
160+
log.warning("Could not run post-rule action %s: %s", action_name, e)
161+
162+
def _get_rule_action_name(
163+
self, rule: Rule, rule_mode: RuleMode, action_name: Optional[str]
164+
) -> Optional[str]:
165+
if action_name is None or action_name == "":
166+
return None
167+
if rule.mode in (RuleMode.WRITEREAD, RuleMode.UPDOWN) and ',' in action_name:
168+
parts = action_name.split(',')
169+
if rule_mode in (RuleMode.WRITE, RuleMode.UPGRADE):
170+
return parts[0]
171+
elif rule_mode in (RuleMode.READ, RuleMode.DOWNGRADE):
172+
return parts[1]
173+
return action_name
174+
175+
def _get_rule_action(self, ctx: RuleContext, action_name: str) -> Optional[RuleAction]:
176+
if action_name == 'ERROR':
177+
return ErrorAction()
178+
elif action_name == 'NONE':
179+
return NoneAction()
180+
return self._rule_registry.get_action(action_name)
181+
182+
183+
class BaseSerializer(BaseSerde, Serializer):
184+
__slots__ = ['_auto_register', '_normalize_schemas', '_schema_id_serializer']
185+
186+
187+
class BaseDeserializer(BaseSerde, Deserializer):
188+
__slots__ = ['_schema_id_deserializer']
189+
190+
def _get_writer_schema(self, schema_id: SchemaId, subject: Optional[str] = None,
191+
fmt: Optional[str] = None) -> Schema:
192+
if schema_id.id is not None:
193+
return self._registry.get_schema(schema_id.id, subject, fmt)
194+
elif schema_id.guid is not None:
195+
return self._registry.get_schema_by_guid(str(schema_id.guid), fmt)
196+
else:
197+
raise SerializationError("Schema ID or GUID is not set")
198+
199+
def _has_rules(self, rule_set: RuleSet, mode: RuleMode) -> bool:
200+
if rule_set is None:
201+
return False
202+
if mode in (RuleMode.UPGRADE, RuleMode.DOWNGRADE):
203+
return any(rule.mode == mode or rule.mode == RuleMode.UPDOWN
204+
for rule in rule_set.migration_rules or [])
205+
elif mode == RuleMode.UPDOWN:
206+
return any(rule.mode == mode for rule in rule_set.migration_rules or [])
207+
elif mode in (RuleMode.WRITE, RuleMode.READ):
208+
return any(rule.mode == mode or rule.mode == RuleMode.WRITEREAD
209+
for rule in rule_set.domain_rules or [])
210+
elif mode == RuleMode.WRITEREAD:
211+
return any(rule.mode == mode for rule in rule_set.migration_rules or [])
212+
return False
213+
214+
def _get_migrations(
215+
self, subject: str, source_info: Schema,
216+
target: RegisteredSchema, fmt: Optional[str]
217+
) -> List[Migration]:
218+
source = self._registry.lookup_schema(subject, source_info, False, True)
219+
migrations = []
220+
if source.version < target.version:
221+
migration_mode = RuleMode.UPGRADE
222+
first = source
223+
last = target
224+
elif source.version > target.version:
225+
migration_mode = RuleMode.DOWNGRADE
226+
first = target
227+
last = source
228+
else:
229+
return migrations
230+
previous: Optional[RegisteredSchema] = None
231+
versions = self._get_schemas_between(subject, first, last, fmt)
232+
for i in range(len(versions)):
233+
version = versions[i]
234+
if i == 0:
235+
previous = version
236+
continue
237+
if version.schema.rule_set is not None and self._has_rules(version.schema.rule_set, migration_mode):
238+
if migration_mode == RuleMode.UPGRADE:
239+
migration = Migration(migration_mode, previous, version)
240+
else:
241+
migration = Migration(migration_mode, version, previous)
242+
migrations.append(migration)
243+
previous = version
244+
if migration_mode == RuleMode.DOWNGRADE:
245+
migrations.reverse()
246+
return migrations
247+
248+
def _get_schemas_between(
249+
self, subject: str, first: RegisteredSchema,
250+
last: RegisteredSchema, fmt: Optional[str] = None
251+
) -> List[RegisteredSchema]:
252+
if last.version - first.version <= 1:
253+
return [first, last]
254+
version1 = first.version
255+
version2 = last.version
256+
result = [first]
257+
for i in range(version1 + 1, version2):
258+
result.append(self._registry.get_version(subject, i, True, fmt))
259+
result.append(last)
260+
return result
261+
262+
def _execute_migrations(
263+
self, ser_ctx: SerializationContext, subject: str,
264+
migrations: List[Migration], message: Any
265+
) -> Any:
266+
for migration in migrations:
267+
message = self._execute_rules(ser_ctx, subject, migration.rule_mode,
268+
migration.source.schema, migration.target.schema,
269+
message, None, None)
270+
return message

0 commit comments

Comments
 (0)