diff --git a/azure/functions/eventgrid.py b/azure/functions/eventgrid.py index e76f5dde..b476614f 100644 --- a/azure/functions/eventgrid.py +++ b/azure/functions/eventgrid.py @@ -52,11 +52,12 @@ class EventGridEventOutConverter(meta.OutConverter, binding="eventGrid"): def check_output_type_annotation(cls, pytype: type) -> bool: valid_types = (str, bytes, azf_eventgrid.EventGridOutputEvent, List[azf_eventgrid.EventGridOutputEvent]) - return (meta.is_iterable_type_annotation(pytype, str) or meta. - is_iterable_type_annotation(pytype, - azf_eventgrid.EventGridOutputEvent) + return (meta.is_iterable_type_annotation(pytype, str) + or meta.is_iterable_type_annotation( + pytype, + azf_eventgrid.EventGridOutputEvent) or (isinstance(pytype, type) - and issubclass(pytype, valid_types))) + and issubclass(pytype, valid_types))) @classmethod def encode(cls, obj: Any, *, expected_type: diff --git a/azure/functions/eventhub.py b/azure/functions/eventhub.py index 6aab2be5..0f3ff8f5 100644 --- a/azure/functions/eventhub.py +++ b/azure/functions/eventhub.py @@ -16,7 +16,8 @@ class EventHubConverter(meta.InConverter, meta.OutConverter, def check_input_type_annotation(cls, pytype: type) -> bool: valid_types = (_eventhub.EventHubEvent) return ( - meta.is_iterable_type_annotation(pytype, valid_types) + meta.is_supported_union_annotation(pytype, valid_types) + or meta.is_iterable_type_annotation(pytype, valid_types) or (isinstance(pytype, type) and issubclass(pytype, valid_types)) ) diff --git a/azure/functions/kafka.py b/azure/functions/kafka.py index 4693e9d1..a4bda36e 100644 --- a/azure/functions/kafka.py +++ b/azure/functions/kafka.py @@ -95,7 +95,8 @@ def check_input_type_annotation(cls, pytype) -> bool: valid_types = (KafkaEvent) return ( - meta.is_iterable_type_annotation(pytype, valid_types) + meta.is_supported_union_annotation(pytype, valid_types) + or meta.is_iterable_type_annotation(pytype, valid_types) or (isinstance(pytype, type) and issubclass(pytype, valid_types)) ) diff --git a/azure/functions/meta.py b/azure/functions/meta.py index 2ca92563..124c4acb 100644 --- a/azure/functions/meta.py +++ b/azure/functions/meta.py @@ -6,7 +6,12 @@ import datetime import json import re +import sys from typing import Dict, Optional, Union, Tuple, Mapping, Any +if sys.version_info >= (3, 9): + from typing import get_origin, get_args +else: + from ._thirdparty.typing_inspect import get_origin, get_args from ._thirdparty import typing_inspect from ._utils import ( @@ -37,6 +42,25 @@ def is_iterable_type_annotation(annotation: object, pytype: object) -> bool: for arg in args) +def is_supported_union_annotation(annotation: object, pytype) -> bool: + """Allows for Union annotation in function apps to be used as a type + hint, as long as the types in the Union are supported. This is + supported for bindings that allow for multiple types. + """ + origin = get_origin(annotation) + if origin is not Union: + return False + + args = get_args(annotation) + for arg in args: + supported = (is_iterable_type_annotation(arg, pytype) + or (isinstance(arg, type) and issubclass(arg, + pytype))) + if not supported: + return False + return True + + class Datum: def __init__(self, value: Any, type: Optional[str]): self.value: Any = value diff --git a/azure/functions/servicebus.py b/azure/functions/servicebus.py index 72a9d254..dc6a2f2e 100644 --- a/azure/functions/servicebus.py +++ b/azure/functions/servicebus.py @@ -228,7 +228,8 @@ class ServiceBusMessageInConverter(meta.InConverter, def check_input_type_annotation(cls, pytype: type) -> bool: valid_types = (azf_sbus.ServiceBusMessage) return ( - meta.is_iterable_type_annotation(pytype, valid_types) + meta.is_supported_union_annotation(pytype, valid_types) + or meta.is_iterable_type_annotation(pytype, valid_types) or (isinstance(pytype, type) and issubclass(pytype, valid_types)) ) diff --git a/tests/test_eventhub.py b/tests/test_eventhub.py index 737438f1..1a7c4501 100644 --- a/tests/test_eventhub.py +++ b/tests/test_eventhub.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - +import sys from typing import List, Mapping import unittest import json @@ -27,6 +27,18 @@ def test_eventhub_input_type(self): self.assertFalse(check_input_type(bytes)) self.assertFalse(check_input_type(List[str])) + @unittest.skipIf(sys.version_info < (3, 10), + reason="requires Python 3.10 or above") + def test_eventhub_input_type_above_310(self): + check_input_type = ( + azf_eh.EventHubConverter.check_input_type_annotation + ) + self.assertTrue(check_input_type( + func.EventHubEvent | List[func.EventHubEvent])) + self.assertFalse(check_input_type(func.EventHubEvent | List[str])) + self.assertFalse(check_input_type(str | List[func.EventHubEvent])) + self.assertFalse(check_input_type(str | List[str])) + def test_eventhub_output_type(self): check_output_type = ( azf_eh.EventHubTriggerConverter.check_output_type_annotation diff --git a/tests/test_kafka.py b/tests/test_kafka.py index 6776c946..9bd8bd22 100644 --- a/tests/test_kafka.py +++ b/tests/test_kafka.py @@ -4,6 +4,7 @@ from typing import List import unittest import json +import sys from unittest.mock import patch import azure.functions as func @@ -40,6 +41,19 @@ def test_kafka_input_type(self): self.assertFalse(check_input_type(bytes)) self.assertFalse(check_input_type(List[str])) + @unittest.skipIf(sys.version_info < (3, 10), + reason="requires Python 3.10 or above") + def test_kafka_input_type_above_310(self): + check_input_type = ( + azf_ka.KafkaConverter.check_input_type_annotation + ) + + self.assertTrue(check_input_type( + func.KafkaEvent | List[func.KafkaEvent])) + self.assertFalse(check_input_type(func.KafkaEvent | List[str])) + self.assertFalse(check_input_type(str | List[func.KafkaEvent])) + self.assertFalse(check_input_type(str | List[str])) + def test_kafka_output_type(self): check_output_type = ( azf_ka.KafkaTriggerConverter.check_output_type_annotation diff --git a/tests/test_servicebus.py b/tests/test_servicebus.py index b02635de..673db9e1 100644 --- a/tests/test_servicebus.py +++ b/tests/test_servicebus.py @@ -3,6 +3,7 @@ from typing import Dict, List import json +import sys import unittest from datetime import datetime, timedelta, date @@ -231,6 +232,19 @@ class ServiceBusMessageChild(func.ServiceBusMessage): self.assertFalse(check_input_type(str)) self.assertFalse(check_input_type(type(None))) + @unittest.skipIf(sys.version_info < (3, 10), + reason="requires Python 3.10 or above") + def test_servicebus_input_type_above_310(self): + check_input_type = ( + azf_sb.ServiceBusMessageInConverter.check_input_type_annotation + ) + + self.assertTrue(check_input_type( + func.ServiceBusMessage | List[func.ServiceBusMessage])) + self.assertFalse(check_input_type(func.ServiceBusMessage | List[str])) + self.assertFalse(check_input_type(str | List[func.ServiceBusMessage])) + self.assertFalse(check_input_type(str | List[str])) + def test_servicebus_output_type(self): check_output_type = ( azf_sb.ServiceBusMessageOutConverter.check_output_type_annotation