Skip to content

Commit 14346dc

Browse files
committed
Initial implementation of conversion node
Reference #26
1 parent 446bc5c commit 14346dc

File tree

1 file changed

+97
-0
lines changed

1 file changed

+97
-0
lines changed

data_prototype/conversion_node.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Iterable, Callable, Sequence
4+
from collections import Counter
5+
from dataclasses import dataclass
6+
import inspect
7+
from functools import cached_property
8+
9+
from typing import Any
10+
11+
12+
def evaluate_pipeline(nodes: Sequence[ConversionNode], input: dict[str, Any]):
    """Feed *input* through every node in *nodes*, in order.

    The first node's ``evaluate`` receives *input*; each later node receives
    whatever the previous node returned. The value returned by the last node
    is the result. With no nodes, *input* itself is returned.
    """
    data = input
    for stage in nodes:
        data = stage.evaluate(data)
    return data
16+
17+
18+
@dataclass
class ConversionNode:
    """Base pipeline node that checks, and optionally trims, a keyed mapping.

    On its own the node performs no conversion: ``evaluate`` either returns
    the input unchanged or a copy restricted to ``output_keys``.
    """

    # Label for the node; not read by any method of this class.
    name: str
    # Keys that ``preview_keys`` requires to be present in the incoming keys.
    required_keys: tuple[str, ...]
    # Keys that must be present in the mapping handed to ``evaluate``.
    output_keys: tuple[str, ...]
    # True: results hold only ``output_keys``; False: input keys pass through.
    trim_keys: bool

    def preview_keys(self, input_keys: Iterable[str]) -> tuple[str, ...]:
        """Return the sorted keys this node would produce for *input_keys*.

        Args:
            input_keys: Keys available to the node. Any iterable is accepted,
                including one-shot iterators and generators.

        Returns:
            The sorted, de-duplicated ``output_keys`` when ``trim_keys`` is
            set; otherwise the sorted union of *input_keys* and
            ``output_keys``.

        Raises:
            ValueError: If any of ``required_keys`` is absent from
                *input_keys*.
        """
        # Materialize exactly once: iterating a one-shot iterator a second
        # time would yield nothing and silently drop the input keys below.
        available = set(input_keys)
        if missing_keys := set(self.required_keys) - available:
            raise ValueError(f"Missing keys: {missing_keys}")
        if self.trim_keys:
            return tuple(sorted(set(self.output_keys)))
        return tuple(sorted(available | set(self.output_keys)))

    def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
        """Validate *input* against ``output_keys`` and optionally trim it.

        Returns:
            With ``trim_keys`` set, a new dict holding only ``output_keys``
            (in that order). Otherwise *input* itself, not a copy.

        Raises:
            KeyError: With ``trim_keys`` set, if an output key is missing.
            ValueError: Without ``trim_keys``, if any output key is missing.
        """
        if self.trim_keys:
            # A missing output key surfaces as KeyError from this lookup.
            return {k: input[k] for k in self.output_keys}
        if missing_keys := set(self.output_keys) - set(input):
            raise ValueError(f"Missing keys: {missing_keys}")
        return input
39+
40+
41+
@dataclass
class UnionConversionNode(ConversionNode):
    """Node that evaluates several child nodes on the same input and merges them."""

    # Child nodes; each one is evaluated against the same input mapping.
    nodes: tuple[ConversionNode, ...]

    @classmethod
    def from_nodes(cls, name: str, *nodes: ConversionNode, trim_keys=False):
        """Build a union node whose key sets are derived from *nodes*.

        ``required_keys`` is the de-duplicated union of the children's
        required keys and ``output_keys`` is the concatenation of their
        output keys, both in first-seen order.

        Raises:
            ValueError: If the same output key is declared by more than one
                child node.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # resulting tuple does not depend on string hash randomization the
        # way iterating a set does.
        required = tuple(dict.fromkeys(k for n in nodes for k in n.required_keys))
        output = Counter(k for n in nodes for k in n.output_keys)
        if duplicate := {k for k, v in output.items() if v > 1}:
            raise ValueError(f"Duplicate keys from multiple input nodes: {duplicate}")
        return cls(name, required, tuple(output), trim_keys, nodes)

    def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
        """Evaluate every child on *input*, merge the results, then apply the base checks.

        Results are merged in child order, so a later child's value wins if
        two children return the same key.
        """
        merged: dict[str, Any] = {}
        for node in self.nodes:
            merged.update(node.evaluate(input))
        return super().evaluate(merged)
55+
56+
57+
@dataclass
class RenameConversionNode(ConversionNode):
    """Node that copies values from existing keys to new key names."""

    # Maps an existing input key to the key its value is also stored under.
    mapping: dict[str, str]

    @classmethod
    def from_mapping(cls, name: str, mapping: dict[str, str], trim_keys=False):
        """Build a rename node from *mapping* (input key -> output key).

        The mapping's keys become ``required_keys`` and its values become
        ``output_keys``.

        Raises:
            ValueError: If two input keys map to the same output key.
        """
        source_keys = tuple(mapping)
        target_counts = Counter(mapping.values())
        duplicate = {key for key, count in target_counts.items() if count > 1}
        if duplicate:
            raise ValueError(f"Duplicate output keys in mapping: {duplicate}")
        return cls(name, source_keys, tuple(target_counts), trim_keys, mapping)

    def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
        """Add each mapped value under its new key, then apply the base checks.

        The original keys are kept in the intermediate mapping; whether they
        survive depends on ``trim_keys`` handling in the base class.
        """
        renamed = {**input}
        for source, target in self.mapping.items():
            # Always read from the untouched input, never from earlier renames.
            renamed[target] = input[source]
        return super().evaluate(renamed)
71+
72+
73+
@dataclass
class FunctionConversionNode(ConversionNode):
    """Node that computes each output key by calling a function on input values."""

    # Maps an output key to the callable that produces its value. The
    # callable's parameter names are used as the input keys to look up.
    funcs: dict[str, Callable]

    @cached_property
    def _sigs(self):
        """Map each output key to ``(func, signature)``; computed on first access."""
        return {k: (f, inspect.signature(f)) for k, f in self.funcs.items()}

    @classmethod
    def from_funcs(cls, name: str, funcs: dict[str, Callable], trim_keys=False):
        """Build a function node, deriving key sets from the callables' signatures.

        ``output_keys`` are the keys of *funcs*; ``required_keys`` are the
        de-duplicated parameter names of all callables in first-seen order.
        """
        sigs = {k: inspect.signature(f) for k, f in funcs.items()}
        output = tuple(sigs)
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # resulting tuple does not depend on string hash randomization the
        # way iterating a set does.
        required = tuple(
            dict.fromkeys(p for sig in sigs.values() for p in sig.parameters)
        )
        return cls(name, required, output, trim_keys, funcs)

    def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
        """Call every function with its named inputs, then apply the base checks.

        Each function receives ``input[p]`` as keyword argument ``p`` for
        every parameter ``p`` in its signature. All functions read from the
        original *input*; computed values override input values that share
        the same key.
        """
        return super().evaluate(
            {
                **input,
                **{
                    k: func(**{p: input[p] for p in sig.parameters})
                    for k, (func, sig) in self._sigs.items()
                },
            }
        )

0 commit comments

Comments
 (0)