diff --git a/elasticsearch/dsl/aggs.py b/elasticsearch/dsl/aggs.py index ba5150803..a20373163 100644 --- a/elasticsearch/dsl/aggs.py +++ b/elasticsearch/dsl/aggs.py @@ -35,6 +35,7 @@ from elastic_transport.client_utils import DEFAULT +from . import wrappers from .query import Query from .response.aggs import AggResponse, BucketData, FieldBucketData, TopHitsData from .utils import _R, AttrDict, DslBase @@ -761,7 +762,7 @@ def __init__( *, after: Union[ Mapping[ - Union[str, "InstrumentedField"], Union[int, float, str, bool, None, Any] + Union[str, "InstrumentedField"], Union[int, float, str, bool, None] ], "DefaultType", ] = DEFAULT, @@ -958,7 +959,7 @@ def __init__( format: Union[str, "DefaultType"] = DEFAULT, missing: Union[str, int, float, bool, "DefaultType"] = DEFAULT, ranges: Union[ - Sequence["types.DateRangeExpression"], + Sequence["wrappers.AggregationRange"], Sequence[Dict[str, Any]], "DefaultType", ] = DEFAULT, @@ -1347,7 +1348,9 @@ def __init__( "DefaultType", ] = DEFAULT, ranges: Union[ - Sequence["types.AggregationRange"], Sequence[Dict[str, Any]], "DefaultType" + Sequence["wrappers.AggregationRange"], + Sequence[Dict[str, Any]], + "DefaultType", ] = DEFAULT, unit: Union[ Literal["in", "ft", "yd", "mi", "nmi", "km", "m", "cm", "mm"], "DefaultType" @@ -2657,7 +2660,9 @@ def __init__( field: Union[str, "InstrumentedField", "DefaultType"] = DEFAULT, missing: Union[int, "DefaultType"] = DEFAULT, ranges: Union[ - Sequence["types.AggregationRange"], Sequence[Dict[str, Any]], "DefaultType" + Sequence["wrappers.AggregationRange"], + Sequence[Dict[str, Any]], + "DefaultType", ] = DEFAULT, script: Union["types.Script", Dict[str, Any], "DefaultType"] = DEFAULT, keyed: Union[bool, "DefaultType"] = DEFAULT, diff --git a/elasticsearch/dsl/faceted_search_base.py b/elasticsearch/dsl/faceted_search_base.py index 5caa041bf..47c887341 100644 --- a/elasticsearch/dsl/faceted_search_base.py +++ b/elasticsearch/dsl/faceted_search_base.py @@ -42,7 +42,7 @@ from .response.aggs import BucketData from .search_base import SearchBase -FilterValueType = Union[str, datetime, Sequence[str]] +FilterValueType = Union[str, int, float, bool] __all__ = [ "FacetedSearchBase", @@ -396,10 +396,10 @@ def add_filter( ] # remember the filter values for use in FacetedResponse - self.filter_values[name] = filter_values # type: ignore[assignment] + self.filter_values[name] = filter_values # get the filter from the facet - f = self.facets[name].add_filter(filter_values) # type: ignore[arg-type] + f = self.facets[name].add_filter(filter_values) if f is None: return diff --git a/elasticsearch/dsl/field.py b/elasticsearch/dsl/field.py index 7fcc9ada5..726fbe358 100644 --- a/elasticsearch/dsl/field.py +++ b/elasticsearch/dsl/field.py @@ -437,7 +437,9 @@ def __init__( doc_class: Union[Type["InnerDoc"], "DefaultType"] = DEFAULT, *args: Any, enabled: Union[bool, "DefaultType"] = DEFAULT, - subobjects: Union[bool, "DefaultType"] = DEFAULT, + subobjects: Union[ + Literal["true", "false", "auto"], bool, "DefaultType" + ] = DEFAULT, copy_to: Union[ Union[str, "InstrumentedField"], Sequence[Union[str, "InstrumentedField"]], @@ -762,6 +764,11 @@ class Boolean(Field): :arg fielddata: :arg index: :arg null_value: + :arg ignore_malformed: + :arg script: + :arg on_script_error: + :arg time_series_dimension: For internal use by Elastic only. Marks + the field as a time series dimension. Defaults to false. :arg doc_values: :arg copy_to: :arg store: @@ -789,6 +796,10 @@ def __init__( ] = DEFAULT, index: Union[bool, "DefaultType"] = DEFAULT, null_value: Union[bool, "DefaultType"] = DEFAULT, + ignore_malformed: Union[bool, "DefaultType"] = DEFAULT, + script: Union["types.Script", Dict[str, Any], "DefaultType"] = DEFAULT, + on_script_error: Union[Literal["fail", "continue"], "DefaultType"] = DEFAULT, + time_series_dimension: Union[bool, "DefaultType"] = DEFAULT, doc_values: Union[bool, "DefaultType"] = DEFAULT, copy_to: Union[ Union[str, "InstrumentedField"], @@ -816,6 +827,14 @@ def __init__( kwargs["index"] = index if null_value is not DEFAULT: kwargs["null_value"] = null_value + if ignore_malformed is not DEFAULT: + kwargs["ignore_malformed"] = ignore_malformed + if script is not DEFAULT: + kwargs["script"] = script + if on_script_error is not DEFAULT: + kwargs["on_script_error"] = on_script_error + if time_series_dimension is not DEFAULT: + kwargs["time_series_dimension"] = time_series_dimension if doc_values is not DEFAULT: kwargs["doc_values"] = doc_values if copy_to is not DEFAULT: @@ -1092,6 +1111,56 @@ def __init__( super().__init__(*args, **kwargs) +class CountedKeyword(Field): + """ + :arg index: + :arg meta: Metadata about the field. + :arg properties: + :arg ignore_above: + :arg dynamic: + :arg fields: + :arg synthetic_source_keep: + """ + + name = "counted_keyword" + _param_defs = { + "properties": {"type": "field", "hash": True}, + "fields": {"type": "field", "hash": True}, + } + + def __init__( + self, + *args: Any, + index: Union[bool, "DefaultType"] = DEFAULT, + meta: Union[Mapping[str, str], "DefaultType"] = DEFAULT, + properties: Union[Mapping[str, Field], "DefaultType"] = DEFAULT, + ignore_above: Union[int, "DefaultType"] = DEFAULT, + dynamic: Union[ + Literal["strict", "runtime", "true", "false"], bool, "DefaultType" + ] = DEFAULT, + fields: Union[Mapping[str, Field], "DefaultType"] = DEFAULT, + synthetic_source_keep: Union[ + Literal["none", "arrays", "all"], "DefaultType" + ] = DEFAULT, + **kwargs: Any, + ): + if index is not DEFAULT: + kwargs["index"] = index + if meta is not DEFAULT: + kwargs["meta"] = meta + if properties is not DEFAULT: + kwargs["properties"] = properties + if ignore_above is not DEFAULT: + kwargs["ignore_above"] = ignore_above + if dynamic is not DEFAULT: + kwargs["dynamic"] = dynamic + if fields is not DEFAULT: + kwargs["fields"] = fields + if synthetic_source_keep is not DEFAULT: + kwargs["synthetic_source_keep"] = synthetic_source_keep + super().__init__(*args, **kwargs) + + class Date(Field): """ :arg default_timezone: timezone that will be automatically used for tz-naive values @@ -1101,6 +1170,8 @@ class Date(Field): :arg format: :arg ignore_malformed: :arg index: + :arg script: + :arg on_script_error: :arg null_value: :arg precision_step: :arg locale: @@ -1133,6 +1204,8 @@ def __init__( format: Union[str, "DefaultType"] = DEFAULT, ignore_malformed: Union[bool, "DefaultType"] = DEFAULT, index: Union[bool, "DefaultType"] = DEFAULT, + script: Union["types.Script", Dict[str, Any], "DefaultType"] = DEFAULT, + on_script_error: Union[Literal["fail", "continue"], "DefaultType"] = DEFAULT, null_value: Any = DEFAULT, precision_step: Union[int, "DefaultType"] = DEFAULT, locale: Union[str, "DefaultType"] = DEFAULT, @@ -1165,6 +1238,10 @@ def __init__( kwargs["ignore_malformed"] = ignore_malformed if index is not DEFAULT: kwargs["index"] = index + if script is not DEFAULT: + kwargs["script"] = script + if on_script_error is not DEFAULT: + kwargs["on_script_error"] = on_script_error if null_value is not DEFAULT: kwargs["null_value"] = null_value if precision_step is not DEFAULT: @@ -1229,6 +1306,8 @@ class DateNanos(Field): :arg format: :arg ignore_malformed: :arg index: + :arg script: + :arg on_script_error: :arg null_value: :arg precision_step: :arg doc_values: @@ -1255,6 +1334,8 @@ def __init__( format: Union[str, "DefaultType"] = DEFAULT, ignore_malformed: Union[bool, "DefaultType"] = DEFAULT, index: Union[bool, "DefaultType"] = DEFAULT, + script: Union["types.Script", Dict[str, Any], "DefaultType"] = DEFAULT, + on_script_error: Union[Literal["fail", "continue"], "DefaultType"] = DEFAULT, null_value: Any = DEFAULT, precision_step: Union[int, "DefaultType"] = DEFAULT, doc_values: Union[bool, "DefaultType"] = DEFAULT, @@ -1284,6 +1365,10 @@ def __init__( kwargs["ignore_malformed"] = ignore_malformed if index is not DEFAULT: kwargs["index"] = index + if script is not DEFAULT: + kwargs["script"] = script + if on_script_error is not DEFAULT: + kwargs["on_script_error"] = on_script_error if null_value is not DEFAULT: kwargs["null_value"] = null_value if precision_step is not DEFAULT: @@ -3068,6 +3153,76 @@ def __init__( super().__init__(*args, **kwargs) +class Passthrough(Field): + """ + :arg enabled: + :arg priority: + :arg time_series_dimension: + :arg copy_to: + :arg store: + :arg meta: Metadata about the field. + :arg properties: + :arg ignore_above: + :arg dynamic: + :arg fields: + :arg synthetic_source_keep: + """ + + name = "passthrough" + _param_defs = { + "properties": {"type": "field", "hash": True}, + "fields": {"type": "field", "hash": True}, + } + + def __init__( + self, + *args: Any, + enabled: Union[bool, "DefaultType"] = DEFAULT, + priority: Union[int, "DefaultType"] = DEFAULT, + time_series_dimension: Union[bool, "DefaultType"] = DEFAULT, + copy_to: Union[ + Union[str, "InstrumentedField"], + Sequence[Union[str, "InstrumentedField"]], + "DefaultType", + ] = DEFAULT, + store: Union[bool, "DefaultType"] = DEFAULT, + meta: Union[Mapping[str, str], "DefaultType"] = DEFAULT, + properties: Union[Mapping[str, Field], "DefaultType"] = DEFAULT, + ignore_above: Union[int, "DefaultType"] = DEFAULT, + dynamic: Union[ + Literal["strict", "runtime", "true", "false"], bool, "DefaultType" + ] = DEFAULT, + fields: Union[Mapping[str, Field], "DefaultType"] = DEFAULT, + synthetic_source_keep: Union[ + Literal["none", "arrays", "all"], "DefaultType" + ] = DEFAULT, + **kwargs: Any, + ): + if enabled is not DEFAULT: + kwargs["enabled"] = enabled + if priority is not DEFAULT: + kwargs["priority"] = priority + if time_series_dimension is not DEFAULT: + kwargs["time_series_dimension"] = time_series_dimension + if copy_to is not DEFAULT: + kwargs["copy_to"] = str(copy_to) + if store is not DEFAULT: + kwargs["store"] = store + if meta is not DEFAULT: + kwargs["meta"] = meta + if properties is not DEFAULT: + kwargs["properties"] = properties + if ignore_above is not DEFAULT: + kwargs["ignore_above"] = ignore_above + if dynamic is not DEFAULT: + kwargs["dynamic"] = dynamic + if fields is not DEFAULT: + kwargs["fields"] = fields + if synthetic_source_keep is not DEFAULT: + kwargs["synthetic_source_keep"] = synthetic_source_keep + super().__init__(*args, **kwargs) + + class Percolator(Field): """ :arg meta: Metadata about the field. diff --git a/elasticsearch/dsl/query.py b/elasticsearch/dsl/query.py index 6e87f926c..1282d3b02 100644 --- a/elasticsearch/dsl/query.py +++ b/elasticsearch/dsl/query.py @@ -1083,6 +1083,8 @@ class Knn(Query): :arg filter: Filters for the kNN search query :arg similarity: The minimum similarity for a vector to be considered a match + :arg rescore_vector: Apply oversampling and rescoring to quantized + vectors * :arg boost: Floating point number used to decrease or increase the relevance scores of the query. Boost values are relative to the default value of 1.0. A boost value between 0 and 1.0 decreases @@ -1108,6 +1110,9 @@ def __init__( k: Union[int, "DefaultType"] = DEFAULT, filter: Union[Query, Sequence[Query], "DefaultType"] = DEFAULT, similarity: Union[float, "DefaultType"] = DEFAULT, + rescore_vector: Union[ + "types.RescoreVector", Dict[str, Any], "DefaultType" + ] = DEFAULT, boost: Union[float, "DefaultType"] = DEFAULT, _name: Union[str, "DefaultType"] = DEFAULT, **kwargs: Any, @@ -1120,6 +1125,7 @@ def __init__( k=k, filter=filter, similarity=similarity, + rescore_vector=rescore_vector, boost=boost, _name=_name, **kwargs, @@ -2650,7 +2656,7 @@ def __init__( self, _field: Union[str, "InstrumentedField", "DefaultType"] = DEFAULT, _value: Union[ - Sequence[Union[int, float, str, bool, None, Any]], + Sequence[Union[int, float, str, bool, None]], "types.TermsLookup", Dict[str, Any], "DefaultType", diff --git a/elasticsearch/dsl/types.py b/elasticsearch/dsl/types.py index 7474769c6..772e596cd 100644 --- a/elasticsearch/dsl/types.py +++ b/elasticsearch/dsl/types.py @@ -26,34 +26,6 @@ PipeSeparatedFlags = str -class AggregationRange(AttrDict[Any]): - """ - :arg from: Start of the range (inclusive). - :arg key: Custom key to return the range with. - :arg to: End of the range (exclusive). - """ - - from_: Union[float, None, DefaultType] - key: Union[str, DefaultType] - to: Union[float, None, DefaultType] - - def __init__( - self, - *, - from_: Union[float, None, DefaultType] = DEFAULT, - key: Union[str, DefaultType] = DEFAULT, - to: Union[float, None, DefaultType] = DEFAULT, - **kwargs: Any, - ): - if from_ is not DEFAULT: - kwargs["from_"] = from_ - if key is not DEFAULT: - kwargs["key"] = key - if to is not DEFAULT: - kwargs["to"] = to - super().__init__(kwargs) - - class BucketCorrelationFunction(AttrDict[Any]): """ :arg count_correlation: (required) The configuration to calculate a @@ -334,34 +306,6 @@ def __init__( super().__init__(kwargs) -class DateRangeExpression(AttrDict[Any]): - """ - :arg from: Start of the range (inclusive). - :arg key: Custom key to return the range with. - :arg to: End of the range (exclusive). - """ - - from_: Union[str, float, DefaultType] - key: Union[str, DefaultType] - to: Union[str, float, DefaultType] - - def __init__( - self, - *, - from_: Union[str, float, DefaultType] = DEFAULT, - key: Union[str, DefaultType] = DEFAULT, - to: Union[str, float, DefaultType] = DEFAULT, - **kwargs: Any, - ): - if from_ is not DEFAULT: - kwargs["from_"] = from_ - if key is not DEFAULT: - kwargs["key"] = key - if to is not DEFAULT: - kwargs["to"] = to - super().__init__(kwargs) - - class DenseVectorIndexOptions(AttrDict[Any]): """ :arg type: (required) The type of kNN algorithm to use. @@ -591,6 +535,7 @@ class FieldSort(AttrDict[Any]): "completion", "nested", "object", + "passthrough", "version", "murmur3", "token_count", @@ -617,6 +562,7 @@ class FieldSort(AttrDict[Any]): "shape", "histogram", "constant_keyword", + "counted_keyword", "aggregate_metric_double", "dense_vector", "semantic_text", @@ -654,6 +600,7 @@ def __init__( "completion", "nested", "object", + "passthrough", "version", "murmur3", "token_count", @@ -680,6 +627,7 @@ def __init__( "shape", "histogram", "constant_keyword", + "counted_keyword", "aggregate_metric_double", "dense_vector", "semantic_text", @@ -2625,7 +2573,7 @@ class PercentageScoreHeuristic(AttrDict[Any]): class PinnedDoc(AttrDict[Any]): """ :arg _id: (required) The unique document ID. - :arg _index: (required) The index that contains the document. + :arg _index: The index that contains the document. """ _id: Union[str, DefaultType] @@ -2850,6 +2798,22 @@ def __init__( super().__init__(kwargs) +class RescoreVector(AttrDict[Any]): + """ + :arg oversample: (required) Applies the specified oversample factor to + k on the approximate kNN search + """ + + oversample: Union[float, DefaultType] + + def __init__( + self, *, oversample: Union[float, DefaultType] = DEFAULT, **kwargs: Any + ): + if oversample is not DEFAULT: + kwargs["oversample"] = oversample + super().__init__(kwargs) + + class ScoreSort(AttrDict[Any]): """ :arg order: @@ -2880,7 +2844,7 @@ class Script(AttrDict[Any]): :arg options: """ - source: Union[str, DefaultType] + source: Union[str, Dict[str, Any], DefaultType] id: Union[str, DefaultType] params: Union[Mapping[str, Any], DefaultType] lang: Union[Literal["painless", "expression", "mustache", "java"], DefaultType] @@ -2889,7 +2853,7 @@ class Script(AttrDict[Any]): def __init__( self, *, - source: Union[str, DefaultType] = DEFAULT, + source: Union[str, Dict[str, Any], DefaultType] = DEFAULT, id: Union[str, DefaultType] = DEFAULT, params: Union[Mapping[str, Any], DefaultType] = DEFAULT, lang: Union[ @@ -3488,14 +3452,14 @@ class SpanTermQuery(AttrDict[Any]): :arg _name: """ - value: Union[str, DefaultType] + value: Union[int, float, str, bool, None, DefaultType] boost: Union[float, DefaultType] _name: Union[str, DefaultType] def __init__( self, *, - value: Union[str, DefaultType] = DEFAULT, + value: Union[int, float, str, bool, None, DefaultType] = DEFAULT, boost: Union[float, DefaultType] = DEFAULT, _name: Union[str, DefaultType] = DEFAULT, **kwargs: Any, @@ -3613,7 +3577,7 @@ class TermQuery(AttrDict[Any]): :arg _name: """ - value: Union[int, float, str, bool, None, Any, DefaultType] + value: Union[int, float, str, bool, None, DefaultType] case_insensitive: Union[bool, DefaultType] boost: Union[float, DefaultType] _name: Union[str, DefaultType] @@ -3621,7 +3585,7 @@ class TermQuery(AttrDict[Any]): def __init__( self, *, - value: Union[int, float, str, bool, None, Any, DefaultType] = DEFAULT, + value: Union[int, float, str, bool, None, DefaultType] = DEFAULT, case_insensitive: Union[bool, DefaultType] = DEFAULT, boost: Union[float, DefaultType] = DEFAULT, _name: Union[str, DefaultType] = DEFAULT, @@ -3712,7 +3676,7 @@ class TermsSetQuery(AttrDict[Any]): :arg _name: """ - terms: Union[Sequence[str], DefaultType] + terms: Union[Sequence[Union[int, float, str, bool, None]], DefaultType] minimum_should_match: Union[int, str, DefaultType] minimum_should_match_field: Union[str, InstrumentedField, DefaultType] minimum_should_match_script: Union["Script", Dict[str, Any], DefaultType] @@ -3722,7 +3686,9 @@ class TermsSetQuery(AttrDict[Any]): def __init__( self, *, - terms: Union[Sequence[str], DefaultType] = DEFAULT, + terms: Union[ + Sequence[Union[int, float, str, bool, None]], DefaultType + ] = DEFAULT, minimum_should_match: Union[int, str, DefaultType] = DEFAULT, minimum_should_match_field: Union[ str, InstrumentedField, DefaultType @@ -4544,7 +4510,7 @@ class CompositeAggregate(AttrDict[Any]): :arg meta: """ - after_key: Mapping[str, Union[int, float, str, bool, None, Any]] + after_key: Mapping[str, Union[int, float, str, bool, None]] buckets: Sequence["CompositeBucket"] meta: Mapping[str, Any] @@ -4559,7 +4525,7 @@ class CompositeBucket(AttrDict[Any]): :arg doc_count: (required) """ - key: Mapping[str, Union[int, float, str, bool, None, Any]] + key: Mapping[str, Union[int, float, str, bool, None]] doc_count: int @@ -5235,9 +5201,7 @@ class Hit(AttrDict[Any]): matched_queries: Union[Sequence[str], Mapping[str, float]] nested: "NestedIdentity" ignored: Sequence[str] - ignored_field_values: Mapping[ - str, Sequence[Union[int, float, str, bool, None, Any]] - ] + ignored_field_values: Mapping[str, Sequence[Any]] shard: str node: str routing: str @@ -5246,7 +5210,7 @@ class Hit(AttrDict[Any]): seq_no: int primary_term: int version: int - sort: Sequence[Union[int, float, str, bool, None, Any]] + sort: Sequence[Union[int, float, str, bool, None]] class HitsMetadata(AttrDict[Any]): @@ -5271,7 +5235,7 @@ class InferenceAggregate(AttrDict[Any]): :arg meta: """ - value: Union[int, float, str, bool, None, Any] + value: Union[int, float, str, bool, None] feature_importance: Sequence["InferenceFeatureImportance"] top_classes: Sequence["InferenceTopClassEntry"] warning: str @@ -5307,7 +5271,7 @@ class InferenceTopClassEntry(AttrDict[Any]): :arg class_score: (required) """ - class_name: Union[int, float, str, bool, None, Any] + class_name: Union[int, float, str, bool, None] class_probability: float class_score: float @@ -5636,7 +5600,7 @@ class MultiTermsBucket(AttrDict[Any]): :arg doc_count_error_upper_bound: """ - key: Sequence[Union[int, float, str, bool, None, Any]] + key: Sequence[Union[int, float, str, bool, None]] doc_count: int key_as_string: str doc_count_error_upper_bound: int @@ -6187,7 +6151,7 @@ class StringTermsBucket(AttrDict[Any]): :arg doc_count_error_upper_bound: """ - key: Union[int, float, str, bool, None, Any] + key: Union[int, float, str, bool, None] doc_count: int doc_count_error_upper_bound: int @@ -6291,7 +6255,7 @@ class TimeSeriesBucket(AttrDict[Any]): :arg doc_count: (required) """ - key: Mapping[str, Union[int, float, str, bool, None, Any]] + key: Mapping[str, Union[int, float, str, bool, None]] doc_count: int @@ -6311,8 +6275,8 @@ class TopMetrics(AttrDict[Any]): :arg metrics: (required) """ - sort: Sequence[Union[Union[int, float, str, bool, None, Any], None]] - metrics: Mapping[str, Union[Union[int, float, str, bool, None, Any], None]] + sort: Sequence[Union[Union[int, float, str, bool, None], None]] + metrics: Mapping[str, Union[Union[int, float, str, bool, None], None]] class TopMetricsAggregate(AttrDict[Any]): diff --git a/elasticsearch/dsl/wrappers.py b/elasticsearch/dsl/wrappers.py index ecd2e1363..52e7af257 100644 --- a/elasticsearch/dsl/wrappers.py +++ b/elasticsearch/dsl/wrappers.py @@ -18,6 +18,7 @@ import operator from typing import ( TYPE_CHECKING, + Any, Callable, ClassVar, Dict, @@ -117,3 +118,27 @@ def lower(self) -> Union[Tuple[RangeValT, bool], Tuple[None, Literal[False]]]: if "gte" in self._d_: return self._d_["gte"], True return None, False + + +class AggregationRange(AttrDict[Any]): + """ + :arg from: Start of the range (inclusive). + :arg key: Custom key to return the range with. + :arg to: End of the range (exclusive). + """ + + def __init__( + self, + *, + from_: Any = None, + key: Optional[str] = None, + to: Any = None, + **kwargs: Any, + ): + if from_ is not None: + kwargs["from_"] = from_ + if key is not None: + kwargs["key"] = key + if to is not None: + kwargs["to"] = to + super().__init__(kwargs) diff --git a/utils/dsl-generator.py b/utils/dsl-generator.py index 3841967e7..2aa12c53d 100644 --- a/utils/dsl-generator.py +++ b/utils/dsl-generator.py @@ -17,14 +17,13 @@ import json import re +import subprocess import textwrap from urllib.error import HTTPError from urllib.request import urlopen from jinja2 import Environment, PackageLoader, select_autoescape -from elasticsearch import VERSION - jinja_env = Environment( loader=PackageLoader("utils"), autoescape=select_autoescape(), @@ -38,7 +37,7 @@ types_py = jinja_env.get_template("types.py.tpl") # map with name replacements for Elasticsearch attributes -PROP_REPLACEMENTS = {"from": "from_"} +PROP_REPLACEMENTS = {"from": "from_", "global": "global_"} # map with Elasticsearch type replacements # keys and values are in given in "{namespace}:{name}" format @@ -115,9 +114,9 @@ def type_for_types_py(type_): class ElasticsearchSchema: """Operations related to the Elasticsearch schema.""" - def __init__(self): + def __init__(self, version="main"): response = None - for branch in [f"{VERSION[0]}.{VERSION[1]}", "main"]: + for branch in [version, "main"]: url = f"https://raw.githubusercontent.com/elastic/elasticsearch-specification/{branch}/output/schema/schema.json" try: response = urlopen(url) @@ -201,6 +200,12 @@ def get_python_type(self, schema_type, for_response=False): ): # QueryContainer maps to the DSL's Query class return "Query", {"type": "query"} + elif ( + type_name["namespace"] == "_global.search._types" + and type_name["name"] == "SearchRequestBody" + ): + # we currently do not provide specific typing for this one + return "Dict[str, Any]", None elif ( type_name["namespace"] == "_types.query_dsl" and type_name["name"] == "FunctionScoreContainer" @@ -219,7 +224,7 @@ def get_python_type(self, schema_type, for_response=False): type_name["namespace"] == "_types.aggregations" and type_name["name"] == "CompositeAggregationSource" ): - # QueryContainer maps to the DSL's Query class + # CompositeAggreagationSource maps to the DSL's Agg class return "Agg[_R]", None else: # for any other instances we get the type and recurse @@ -300,6 +305,8 @@ def get_python_type(self, schema_type, for_response=False): ] ) ) + if len(types) == 1: + return types[0] return "Union[" + ", ".join([type_ for type_, _ in types]) + "]", None elif schema_type["kind"] == "enum": @@ -338,6 +345,12 @@ def get_python_type(self, schema_type, for_response=False): ]["name"].endswith("Analyzer"): # not expanding analyzers at this time, maybe in the future return "str, Dict[str, Any]", None + elif ( + schema_type["name"]["namespace"] == "_types.aggregations" + and schema_type["name"]["name"].endswith("AggregationRange") + and schema_type["name"]["name"] != "IpRangeAggregationRange" + ): + return '"wrappers.AggregationRange"', None # to handle other interfaces we generate a type of the same name # and add the interface to the interfaces.py module @@ -380,9 +393,12 @@ def add_attribute(self, k, arg, for_types_py=False, for_response=False): param = None if not for_response: if type_ != "Any": - if 'Sequence["types.' in type_: + if ( + 'Sequence["types.' in type_ + or 'Sequence["wrappers.AggregationRange' in type_ + ): type_ = add_seq_dict_type(type_) # interfaces can be given as dicts - elif "types." in type_: + elif "types." in type_ or "wrappers.AggregationRange" in type_: type_ = add_dict_type(type_) # interfaces can be given as dicts type_ = add_not_set(type_) if for_types_py: @@ -999,7 +1015,8 @@ def generate_types_py(schema, filename): if __name__ == "__main__": - schema = ElasticsearchSchema() + v = subprocess.check_output(["git", "branch", "--show-current"]).strip().decode() + schema = ElasticsearchSchema(v) generate_field_py(schema, "elasticsearch/dsl/field.py") generate_query_py(schema, "elasticsearch/dsl/query.py") generate_aggs_py(schema, "elasticsearch/dsl/aggs.py") diff --git a/utils/templates/aggs.py.tpl b/utils/templates/aggs.py.tpl index d4ba4f4cd..68d46e63d 100644 --- a/utils/templates/aggs.py.tpl +++ b/utils/templates/aggs.py.tpl @@ -38,6 +38,7 @@ from elastic_transport.client_utils import DEFAULT from .query import Query from .response.aggs import AggResponse, BucketData, FieldBucketData, TopHitsData from .utils import _R, AttrDict, DslBase +from . import wrappers if TYPE_CHECKING: from elastic_transport.client_utils import DefaultType