Skip to content

Commit b517f9f

Browse files
committed
Replace type guessing with type conversion
1 parent 7c7ab58 commit b517f9f

File tree

3 files changed

+41
-199
lines changed

3 files changed

+41
-199
lines changed

python/cocoindex/typing.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -240,7 +240,7 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
240240
elif type_info.kind == 'Union':
241241
if type_info.elem_type is not types.UnionType:
242242
raise ValueError("Union type must have a union-typed element type")
243-
encoded_type['types'] = [
243+
encoded_type['possible_types'] = [
244244
_encode_type(analyze_type_info(typ)) for typ in type_info.union_variant_types
245245
]
246246

src/base/value.rs

Lines changed: 5 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -934,60 +934,6 @@ impl BasicValue {
934934
.collect::<Result<Vec<_>>>()?;
935935
BasicValue::Vector(Arc::from(vec))
936936
}
937-
(v, BasicValueType::Union(typ)) => {
938-
let types = typ.types();
939-
940-
match v {
941-
serde_json::Value::Bool(b) if types.contains(&BasicValueType::Bool) => {
942-
BasicValue::Bool(b)
943-
}
944-
serde_json::Value::Number(n) => {
945-
if types.contains(&BasicValueType::Int64) {
946-
match n.as_i64() {
947-
Some(n) => return Ok(BasicValue::Int64(n)),
948-
None => {}
949-
}
950-
}
951-
952-
if types.contains(&BasicValueType::Float64) {
953-
match n.as_f64() {
954-
Some(n) => return Ok(BasicValue::Float64(n)),
955-
None => {}
956-
}
957-
}
958-
959-
if types.contains(&BasicValueType::Float32) {
960-
match n.as_f64().map(|v| v as f32) {
961-
Some(n) => return Ok(BasicValue::Float32(n)),
962-
None => {}
963-
}
964-
}
965-
966-
anyhow::bail!("Invalid number value {n}")
967-
}
968-
serde_json::Value::Object(obj) if types.contains(&BasicValueType::Range) => {
969-
let start = obj.get("start").and_then(|val| val.as_u64());
970-
let end = obj.get("end").and_then(|val| val.as_u64());
971-
972-
match (start, end) {
973-
(Some(start), Some(end)) => {
974-
BasicValue::Range(RangeValue::new(start as usize, end as usize))
975-
}
976-
_ => anyhow::bail!("Invalid range value")
977-
}
978-
}
979-
serde_json::Value::String(s) => {
980-
match types.parse_str(&s) {
981-
Ok(val) => return Ok(val),
982-
Err(_) => {}
983-
}
984-
985-
anyhow::bail!("Invalid string value \"{s}\"")
986-
}
987-
988-
_ => anyhow::bail!("Invalid union value {v}, expect type {}", types.iter().join(" | ")),
989-
}
990-
}
991937
(v, t) => {
992938
anyhow::bail!("Value and type not matched.\nTarget type {t:?}\nJSON value: {v}\n")
993939
}
@@ -1120,7 +1066,11 @@ impl Serialize for TypedValue<'_> {
11201066
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
11211067
match (self.t, self.v) {
11221068
(_, Value::Null) => serializer.serialize_none(),
1123-
(ValueType::Basic(_), v) => v.serialize(serializer),
1069+
(ValueType::Basic(typ), v) => match typ {
1070+
BasicValueType::Union(s) => {
1071+
}
1072+
_ => v.serialize(serializer),
1073+
},
11241074
(ValueType::Struct(s), Value::Struct(field_values)) => TypedFieldsValue {
11251075
schema: &s.fields,
11261076
values_iter: field_values.fields.iter(),

src/utils/union.rs

Lines changed: 35 additions & 143 deletions
Original file line number | Diff line number | Diff line change
@@ -1,158 +1,50 @@
1-
use std::{str::FromStr, sync::Arc};
2-
31
use crate::{base::{schema::BasicValueType, value::BasicValue}, prelude::*};
42

5-
#[derive(Debug, Clone)]
6-
pub enum UnionParseResult {
7-
Union(UnionType),
8-
Single(BasicValueType),
9-
}
10-
113
/// Union type helper storing an auto-sorted set of types excluding `Union`
124
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
135
pub struct UnionType {
14-
types: BTreeSet<BasicValueType>,
6+
possible_types: BTreeSet<BasicValueType>,
157
}
168

179
impl UnionType {
18-
fn new() -> Self {
19-
Self { types: BTreeSet::new() }
20-
}
21-
22-
pub fn types(&self) -> &BTreeSet<BasicValueType> {
23-
&self.types
24-
}
25-
26-
pub fn insert(&mut self, value: BasicValueType) -> bool {
27-
match value {
28-
BasicValueType::Union(union_type) => {
29-
let mut inserted = false;
30-
31-
// Unpack nested union
32-
for item in union_type.types.into_iter() {
33-
// Recursively insert underlying types
34-
inserted = self.insert(item) || inserted;
10+
pub fn possible_types(&self) -> &BTreeSet<BasicValueType> {
11+
&self.possible_types
12+
}
13+
14+
pub fn match_type(&self, value: &BasicValue) -> Option<BasicValueType> {
15+
// TODO: Use a general converter
16+
let typ = match value {
17+
BasicValue::Bytes(_) => Some(BasicValueType::Bytes),
18+
BasicValue::Str(_) => Some(BasicValueType::Str),
19+
BasicValue::Bool(_) => Some(BasicValueType::Bool),
20+
BasicValue::Int64(_) => Some(BasicValueType::Int64),
21+
BasicValue::Float32(_) => Some(BasicValueType::Float32),
22+
BasicValue::Float64(_) => Some(BasicValueType::Float64),
23+
BasicValue::Range(_) => Some(BasicValueType::Range),
24+
BasicValue::Uuid(_) => Some(BasicValueType::Uuid),
25+
BasicValue::Date(_) => Some(BasicValueType::Date),
26+
BasicValue::Time(_) => Some(BasicValueType::Time),
27+
BasicValue::LocalDateTime(_) => Some(BasicValueType::LocalDateTime),
28+
BasicValue::OffsetDateTime(_) => Some(BasicValueType::OffsetDateTime),
29+
BasicValue::Json(_) => Some(BasicValueType::Json),
30+
BasicValue::Vector(v) => {
31+
match v.first() {
32+
Some(first_elem) => self.match_type(first_elem),
33+
None => None,
3534
}
36-
37-
inserted
3835
}
36+
};
3937

40-
other => self.types.insert(other),
41-
}
42-
}
43-
44-
fn resolve(self) -> Result<UnionParseResult> {
45-
if self.types().is_empty() {
46-
anyhow::bail!("The union is empty");
47-
}
48-
49-
if self.types().len() == 1 {
50-
let mut type_tree: BTreeSet<BasicValueType> = self.into();
51-
return Ok(UnionParseResult::Single(type_tree.pop_first().unwrap()));
52-
}
53-
54-
Ok(UnionParseResult::Union(self))
55-
}
56-
57-
/// Move an iterable and parse it into a union type.
58-
/// If there is only one single unique type, it returns a single `BasicValueType`.
59-
pub fn parse_from<T>(
60-
input: impl IntoIterator<Item = BasicValueType, IntoIter = T>,
61-
) -> Result<UnionParseResult>
62-
where
63-
T: Iterator<Item = BasicValueType>,
64-
{
65-
let mut union = Self::new();
66-
67-
for typ in input {
68-
union.insert(typ);
69-
}
70-
71-
union.resolve()
72-
}
73-
74-
/// Assume the input already contains multiple unique types, panic otherwise.
75-
///
76-
/// This method is meant for streamlining the code for test cases.
77-
/// Use `parse_from()` instead unless you know the input.
78-
pub fn coerce_from<T>(
79-
input: impl IntoIterator<Item = BasicValueType, IntoIter = T>,
80-
) -> Self
81-
where
82-
T: Iterator<Item = BasicValueType>,
83-
{
84-
match Self::parse_from(input) {
85-
Ok(UnionParseResult::Union(union)) => union,
86-
_ => panic!("Do not use `coerce_from()` for basic type lists that can possibly be one type."),
87-
}
38+
typ.and_then(|typ| {
39+
self.possible_types()
40+
.contains(&typ)
41+
.then_some(typ)
42+
})
8843
}
8944
}
9045

91-
impl Into<BTreeSet<BasicValueType>> for UnionType {
92-
fn into(self) -> BTreeSet<BasicValueType> {
93-
self.types
94-
}
95-
}
96-
97-
pub trait ParseStr {
98-
type Out;
99-
type Err;
100-
101-
fn parse_str(&self, value: &str) -> Result<Self::Out, Self::Err>;
102-
}
103-
104-
impl ParseStr for BTreeSet<BasicValueType> {
105-
type Out = BasicValue;
106-
type Err = anyhow::Error;
107-
108-
/// Try parsing the str value to each possible type, and return the first successful result
109-
fn parse_str(&self, value: &str) -> Result<BasicValue> {
110-
// Try parsing the value in the reversed order of the enum elements
111-
for typ in self.iter().rev() {
112-
match typ {
113-
BasicValueType::Uuid => {
114-
match value.parse().map(BasicValue::Uuid) {
115-
Ok(ret) => return Ok(ret),
116-
Err(_) => {}
117-
}
118-
}
119-
BasicValueType::OffsetDateTime => {
120-
match value.parse().map(BasicValue::OffsetDateTime) {
121-
Ok(ret) => return Ok(ret),
122-
Err(_) => {}
123-
}
124-
}
125-
BasicValueType::LocalDateTime => {
126-
match value.parse().map(BasicValue::LocalDateTime) {
127-
Ok(ret) => return Ok(ret),
128-
Err(_) => {}
129-
}
130-
}
131-
BasicValueType::Date => {
132-
match value.parse().map(BasicValue::Date) {
133-
Ok(ret) => return Ok(ret),
134-
Err(_) => {}
135-
}
136-
}
137-
BasicValueType::Time => {
138-
match value.parse().map(BasicValue::Time) {
139-
Ok(ret) => return Ok(ret),
140-
Err(_) => {}
141-
}
142-
}
143-
BasicValueType::Json => {
144-
match serde_json::Value::from_str(value) {
145-
Ok(ret) => return Ok(BasicValue::Json(ret.into())),
146-
Err(_) => {}
147-
}
148-
}
149-
BasicValueType::Str => {
150-
return Ok(BasicValue::Str(Arc::from(value)));
151-
}
152-
_ => {}
153-
}
154-
}
155-
156-
anyhow::bail!("Cannot parse \"{}\"", value)
157-
}
46+
#[derive(Debug, Clone, Serialize)]
47+
pub struct UnionVariant<'a> {
48+
pub tag_id: &'a str,
49+
pub value: &'a BasicValue,
15850
}

0 commit comments

Comments
 (0)