|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from collections.abc import Iterable, Callable, Sequence |
| 4 | +from collections import Counter |
| 5 | +from dataclasses import dataclass |
| 6 | +import inspect |
| 7 | +from functools import cached_property |
| 8 | + |
| 9 | +from typing import Any |
| 10 | + |
| 11 | + |
def evaluate_pipeline(nodes: Sequence[ConversionNode], input: dict[str, Any]):
    """Thread ``input`` through every node in order and return the final result.

    Each node's ``evaluate`` receives whatever the previous node returned; the
    first node receives ``input`` itself.  When ``nodes`` is empty, ``input``
    is returned unchanged (the very same object).
    """
    current = input
    for step in nodes:
        current = step.evaluate(current)
    return current
| 16 | + |
| 17 | + |
@dataclass
class ConversionNode:
    """Base pipeline step that validates, and optionally trims, a keyed mapping.

    The base node performs no transformation of its own: ``evaluate`` either
    returns the input untouched (after checking the output keys are present)
    or, when ``trim_keys`` is set, a new dict restricted to ``output_keys``.
    """

    # Label for the node; not read by any method in this class.
    name: str
    # Keys that ``preview_keys`` insists on finding among the input keys.
    required_keys: tuple[str, ...]
    # Keys that ``evaluate`` keeps (trimming) or checks for (not trimming).
    output_keys: tuple[str, ...]
    # When true, only ``output_keys`` survive; otherwise input keys pass through.
    trim_keys: bool

    def preview_keys(self, input_keys: Iterable[str]) -> tuple[str, ...]:
        """Return the sorted key set this node would emit for ``input_keys``.

        Args:
            input_keys: Keys available to the node.  Any iterable is accepted,
                including one-shot iterators and generators.

        Returns:
            A sorted tuple of unique keys: just ``output_keys`` when trimming,
            otherwise the union of the input keys and ``output_keys``.

        Raises:
            ValueError: If any of ``required_keys`` is absent from
                ``input_keys``.
        """
        # Materialise exactly once: the keys are needed twice below, and a
        # second pass over an iterator would silently see nothing.
        available = set(input_keys)
        if missing_keys := set(self.required_keys) - available:
            raise ValueError(f"Missing keys: {missing_keys}")
        if self.trim_keys:
            return tuple(sorted(set(self.output_keys)))
        return tuple(sorted(available | set(self.output_keys)))

    def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
        """Validate ``input`` against ``output_keys`` and trim it if requested.

        Returns:
            A new dict holding only ``output_keys`` when ``trim_keys`` is set;
            otherwise ``input`` itself (not a copy).

        Raises:
            KeyError: When trimming and an output key is absent from ``input``.
            ValueError: When not trimming and an output key is absent.
        """
        if self.trim_keys:
            return {k: input[k] for k in self.output_keys}
        if missing_keys := set(self.output_keys) - set(input):
            raise ValueError(f"Missing keys: {missing_keys}")
        return input
| 39 | + |
| 40 | + |
@dataclass
class UnionConversionNode(ConversionNode):
    """Node that runs several child nodes on the same input and merges the results."""

    # Children evaluated, in order, against the *same* input mapping.
    nodes: tuple[ConversionNode, ...]

    @classmethod
    def from_nodes(cls, name: str, *nodes: ConversionNode, trim_keys=False):
        """Build a union whose key sets are derived from ``nodes``.

        ``required_keys`` becomes the de-duplicated concatenation of the
        children's required keys; ``output_keys`` the concatenation of their
        output keys.

        Raises:
            ValueError: If two children declare the same output key.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # resulting tuple is reproducible (set iteration order is arbitrary).
        required = tuple(dict.fromkeys(k for n in nodes for k in n.required_keys))
        output = Counter(k for n in nodes for k in n.output_keys)
        if duplicate := {k for k, v in output.items() if v > 1}:
            raise ValueError(f"Duplicate keys from multiple input nodes: {duplicate}")
        return cls(name, required, tuple(output), trim_keys, nodes)

    def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
        """Evaluate every child on ``input`` and merge their results.

        The merged mapping is then passed to ``ConversionNode.evaluate`` for
        the usual output-key check / trimming.
        """
        merged: dict[str, Any] = {}
        produced: dict[str, Any] = {}
        for node in self.nodes:
            result = node.evaluate(input)
            merged.update(result)
            # Remember what this child actually declared as its output.  Keys
            # missing from ``result`` are skipped here so the base-class check
            # below still reports them.
            produced.update((k, result[k]) for k in node.output_keys if k in result)
        # A non-trimming child passes the original input keys through; without
        # this step a later child's pass-through copy of a key would overwrite
        # the value an earlier child computed for that same key.
        merged.update(produced)
        return super().evaluate(merged)
| 55 | + |
| 56 | + |
@dataclass
class RenameConversionNode(ConversionNode):
    """Node that copies values from source keys to new target keys."""

    # Maps each source key in the input to the key it is re-emitted under.
    mapping: dict[str, str]

    @classmethod
    def from_mapping(cls, name: str, mapping: dict[str, str], trim_keys=False):
        """Create a rename node, deriving its key sets from ``mapping``.

        The mapping's keys become ``required_keys`` and its values become
        ``output_keys``.

        Raises:
            ValueError: If two source keys map to the same target key.
        """
        target_counts = Counter(mapping.values())
        duplicate = {target for target, count in target_counts.items() if count > 1}
        if duplicate:
            raise ValueError(f"Duplicate output keys in mapping: {duplicate}")
        source_keys = tuple(mapping)
        target_keys = tuple(target_counts)
        return cls(name, source_keys, target_keys, trim_keys, mapping)

    def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
        """Add each target key with the value of its source key, then validate.

        The source keys are not removed here; whether they survive depends on
        the trimming done by ``ConversionNode.evaluate``.
        """
        renamed = dict(input)
        for source_key, target_key in self.mapping.items():
            # Always read from the untouched input so one rename cannot feed
            # into another.
            renamed[target_key] = input[source_key]
        return super().evaluate(renamed)
| 71 | + |
| 72 | + |
@dataclass
class FunctionConversionNode(ConversionNode):
    """Node that computes each output key by calling a function on input values.

    A function's parameter names select which input keys it receives.
    """

    # Maps each output key to the callable that produces its value.
    funcs: dict[str, Callable]

    @cached_property
    def _sigs(self):
        """Map each output key to ``(function, signature)``; computed once per instance."""
        return {k: (f, inspect.signature(f)) for k, f in self.funcs.items()}

    @staticmethod
    def _is_variadic(param: inspect.Parameter) -> bool:
        """Tell whether ``param`` is ``*args`` / ``**kwargs`` (never an input key)."""
        return param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )

    @classmethod
    def _bind_arguments(cls, sig: inspect.Signature, input: dict[str, Any]) -> dict[str, Any]:
        """Collect the keyword arguments for one function from ``input``."""
        kwargs = {}
        for pname, param in sig.parameters.items():
            if cls._is_variadic(param):
                continue
            # Mandatory parameters are always looked up (raising KeyError when
            # absent); defaulted ones only when the input actually has them.
            if pname in input or param.default is inspect.Parameter.empty:
                kwargs[pname] = input[pname]
        return kwargs

    @classmethod
    def from_funcs(cls, name: str, funcs: dict[str, Callable], trim_keys=False):
        """Create a node whose key sets are derived from ``funcs``.

        ``output_keys`` are the keys of ``funcs``.  ``required_keys`` are the
        names of all parameters that must be supplied: variadic parameters and
        parameters with a default value are not required.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # resulting tuple is reproducible (set iteration order is arbitrary).
        required = tuple(
            dict.fromkeys(
                param.name
                for func in funcs.values()
                for param in inspect.signature(func).parameters.values()
                if param.default is inspect.Parameter.empty and not cls._is_variadic(param)
            )
        )
        return cls(name, required, tuple(funcs), trim_keys, funcs)

    def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
        """Call every function on ``input`` and merge the results over it.

        All functions see the original ``input``, not each other's results.

        Raises:
            KeyError: If a parameter without a default is absent from ``input``.
        """
        computed = {
            key: func(**self._bind_arguments(sig, input))
            for key, (func, sig) in self._sigs.items()
        }
        return super().evaluate({**input, **computed})
0 commit comments