From 9adb1b2822bd124727e261c2f8b376d526c142cd Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Wed, 30 Aug 2023 14:24:15 -0700 Subject: [PATCH 1/5] enable azure for audio --- openai/api_resources/audio.py | 382 ++++++++++++++++++++++++++- openai/tests/test_audio_overloads.py | 175 ++++++++++++ 2 files changed, 547 insertions(+), 10 deletions(-) create mode 100644 openai/tests/test_audio_overloads.py diff --git a/openai/api_resources/audio.py b/openai/api_resources/audio.py index d5d906ed96..2fe321d6d9 100644 --- a/openai/api_resources/audio.py +++ b/openai/api_resources/audio.py @@ -1,15 +1,33 @@ -from typing import Any, List +from typing import Any, List, overload import openai from openai import api_requestor, util from openai.api_resources.abstract import APIResource +def check_required(*args, method_name, required, **kwargs): + missing = [] + args_count = len(args) + for param in required: + if param in kwargs: + continue + elif args_count > 0: + args_count -= 1 + continue + else: + missing.append(param) + + if missing and "deployment_id" not in kwargs: + raise TypeError(f"{method_name}() missing {len(missing)} required positional argument(s): {', '.join(missing)}") + + class Audio(APIResource): OBJECT_NAME = "audio" @classmethod - def _get_url(cls, action): + def _get_url(cls, action, deployment_id=None, api_type=None, api_version=None): + if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): + return f"/{cls.azure_api_prefix}/deployments/{deployment_id}/audio/{action}?api-version={api_version}" return cls.class_url() + f"/{action}" @classmethod @@ -40,6 +58,23 @@ def _prepare_request( files.append(("file", (filename, file, "application/octet-stream"))) return requestor, files, data + @overload + @classmethod + def transcribe( + cls, + *, + deployment_id=None, + file=None, + api_key=None, + api_base=None, + api_type=None, + api_version=None, + organization=None, + **params, + ): + ... + + @overload @classmethod def transcribe( cls, @@ -52,6 +87,30 @@ def transcribe( organization=None, **params, ): + ... + + @classmethod + def transcribe( + cls, + *args, + **params, + ): + if len(args) > 7: + raise TypeError( + f"transcribe() takes from 3 to 8 positional arguments but {len(args)+1} were given" + ) + check_required(*args, method_name="transcribe", required=["model", "file"], **params) + + positional = list(args) + model = positional.pop(0) if positional else params.pop("model", None) + file = positional.pop(0) if positional else params.pop("file", None) + api_key = positional.pop(0) if positional else params.pop("api_key", None) + api_base = positional.pop(0) if positional else params.pop("api_base", None) + api_type = positional.pop(0) if positional else params.pop("api_type", None) + api_version = positional.pop(0) if positional else params.pop("api_version", None) + organization = positional.pop(0) if positional else params.pop("organization", None) + deployment_id = params.pop("deployment_id", None) + requestor, files, data = cls._prepare_request( file=file, filename=file.name, @@ -63,12 +122,30 @@ def transcribe( organization=organization, **params, ) - url = cls._get_url("transcriptions") + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) + url = cls._get_url("transcriptions", deployment_id=deployment_id, api_type=api_type, api_version=api_version) response, _, api_key = requestor.request("post", url, files=files, params=data) return util.convert_to_openai_object( response, api_key, api_version, organization ) + @overload + @classmethod + def translate( + cls, + *, + deployment_id=None, + file=None, + api_key=None, + api_base=None, + api_type=None, + api_version=None, + organization=None, + **params, + ): + ... + + @overload @classmethod def translate( cls, @@ -81,6 +158,30 @@ def translate( organization=None, **params, ): + ... + + @classmethod + def translate( + cls, + *args, + **params, + ): + if len(args) > 7: + raise TypeError( + f"translate() takes from 3 to 8 positional arguments but {len(args)+1} were given" + ) + check_required(*args, method_name="translate", required=["model", "file"], **params) + + positional = list(args) + model = positional.pop(0) if positional else params.pop("model", None) + file = positional.pop(0) if positional else params.pop("file", None) + api_key = positional.pop(0) if positional else params.pop("api_key", None) + api_base = positional.pop(0) if positional else params.pop("api_base", None) + api_type = positional.pop(0) if positional else params.pop("api_type", None) + api_version = positional.pop(0) if positional else params.pop("api_version", None) + organization = positional.pop(0) if positional else params.pop("organization", None) + deployment_id = params.pop("deployment_id", None) + requestor, files, data = cls._prepare_request( file=file, filename=file.name, @@ -92,12 +193,31 @@ def translate( organization=organization, **params, ) - url = cls._get_url("translations") + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) + url = cls._get_url("translations", deployment_id=deployment_id, api_type=api_type, api_version=api_version) response, _, api_key = requestor.request("post", url, files=files, params=data) return util.convert_to_openai_object( response, api_key, api_version, organization ) + @overload + @classmethod + def transcribe_raw( + cls, + *, + deployment_id=None, + file=None, + filename=None, + api_key=None, + api_base=None, + api_type=None, + api_version=None, + organization=None, + **params, + ): + ... + + @overload @classmethod def transcribe_raw( cls, @@ -111,6 +231,31 @@ def transcribe_raw( organization=None, **params, ): + ... + + @classmethod + def transcribe_raw( + cls, + *args, + **params, + ): + if len(args) > 8: + raise TypeError( + f"transcribe_raw() takes from 4 to 9 positional arguments but {len(args)+1} were given" + ) + check_required(*args, method_name="transcribe_raw", required=["model", "file", "filename"], **params) + + positional = list(args) + model = positional.pop(0) if positional else params.pop("model", None) + file = positional.pop(0) if positional else params.pop("file", None) + filename = positional.pop(0) if positional else params.pop("filename", None) + api_key = positional.pop(0) if positional else params.pop("api_key", None) + api_base = positional.pop(0) if positional else params.pop("api_base", None) + api_type = positional.pop(0) if positional else params.pop("api_type", None) + api_version = positional.pop(0) if positional else params.pop("api_version", None) + organization = positional.pop(0) if positional else params.pop("organization", None) + deployment_id = params.pop("deployment_id", None) + requestor, files, data = cls._prepare_request( file=file, filename=filename, @@ -122,12 +267,31 @@ def transcribe_raw( organization=organization, **params, ) - url = cls._get_url("transcriptions") + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) + url = cls._get_url("transcriptions", deployment_id=deployment_id, api_type=api_type, api_version=api_version) response, _, api_key = requestor.request("post", url, files=files, params=data) return util.convert_to_openai_object( response, api_key, api_version, organization ) + @overload + @classmethod + def translate_raw( + cls, + *, + deployment_id=None, + file=None, + filename=None, + api_key=None, + api_base=None, + api_type=None, + api_version=None, + organization=None, + **params, + ): + ... + + @overload @classmethod def translate_raw( cls, @@ -141,6 +305,31 @@ def translate_raw( organization=None, **params, ): + ... + + @classmethod + def translate_raw( + cls, + *args, + **params, + ): + if len(args) > 8: + raise TypeError( + f"translate_raw() takes from 4 to 9 positional arguments but {len(args)+1} were given" + ) + check_required(*args, method_name="translate_raw", required=["model", "file", "filename"], **params) + + positional = list(args) + model = positional.pop(0) if positional else params.pop("model", None) + file = positional.pop(0) if positional else params.pop("file", None) + filename = positional.pop(0) if positional else params.pop("filename", None) + api_key = positional.pop(0) if positional else params.pop("api_key", None) + api_base = positional.pop(0) if positional else params.pop("api_base", None) + api_type = positional.pop(0) if positional else params.pop("api_type", None) + api_version = positional.pop(0) if positional else params.pop("api_version", None) + organization = positional.pop(0) if positional else params.pop("organization", None) + deployment_id = params.pop("deployment_id", None) + requestor, files, data = cls._prepare_request( file=file, filename=filename, @@ -152,12 +341,30 @@ def translate_raw( organization=organization, **params, ) - url = cls._get_url("translations") + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) + url = cls._get_url("translations", deployment_id=deployment_id, api_type=api_type, api_version=api_version) response, _, api_key = requestor.request("post", url, files=files, params=data) return util.convert_to_openai_object( response, api_key, api_version, organization ) + @overload + @classmethod + async def atranscribe( + cls, + *, + deployment_id=None, + file=None, + api_key=None, + api_base=None, + api_type=None, + api_version=None, + organization=None, + **params, + ): + ... + + @overload @classmethod async def atranscribe( cls, @@ -170,6 +377,30 @@ async def atranscribe( organization=None, **params, ): + ... + + @classmethod + async def atranscribe( + cls, + *args, + **params, + ): + if len(args) > 7: + raise TypeError( + f"atranscribe() takes from 3 to 8 positional arguments but {len(args)+1} were given" + ) + check_required(*args, method_name="atranscribe", required=["model", "file"], **params) + + positional = list(args) + model = positional.pop(0) if positional else params.pop("model", None) + file = positional.pop(0) if positional else params.pop("file", None) + api_key = positional.pop(0) if positional else params.pop("api_key", None) + api_base = positional.pop(0) if positional else params.pop("api_base", None) + api_type = positional.pop(0) if positional else params.pop("api_type", None) + api_version = positional.pop(0) if positional else params.pop("api_version", None) + organization = positional.pop(0) if positional else params.pop("organization", None) + deployment_id = params.pop("deployment_id", None) + requestor, files, data = cls._prepare_request( file=file, filename=file.name, @@ -181,7 +412,8 @@ async def atranscribe( organization=organization, **params, ) - url = cls._get_url("transcriptions") + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) + url = cls._get_url("transcriptions", deployment_id=deployment_id, api_type=api_type, api_version=api_version) response, _, api_key = await requestor.arequest( "post", url, files=files, params=data ) @@ -189,6 +421,23 @@ async def atranscribe( response, api_key, api_version, organization ) + @overload + @classmethod + async def atranslate( + cls, + *, + deployment_id=None, + file=None, + api_key=None, + api_base=None, + api_type=None, + api_version=None, + organization=None, + **params, + ): + ... + + @overload @classmethod async def atranslate( cls, @@ -201,6 +450,30 @@ async def atranslate( organization=None, **params, ): + ... + + @classmethod + async def atranslate( + cls, + *args, + **params, + ): + if len(args) > 7: + raise TypeError( + f"atranslate() takes from 3 to 8 positional arguments but {len(args)+1} were given" + ) + check_required(*args, method_name="atranslate", required=["model", "file"], **params) + + positional = list(args) + model = positional.pop(0) if positional else params.pop("model", None) + file = positional.pop(0) if positional else params.pop("file", None) + api_key = positional.pop(0) if positional else params.pop("api_key", None) + api_base = positional.pop(0) if positional else params.pop("api_base", None) + api_type = positional.pop(0) if positional else params.pop("api_type", None) + api_version = positional.pop(0) if positional else params.pop("api_version", None) + organization = positional.pop(0) if positional else params.pop("organization", None) + deployment_id = params.pop("deployment_id", None) + requestor, files, data = cls._prepare_request( file=file, filename=file.name, @@ -212,7 +485,8 @@ async def atranslate( organization=organization, **params, ) - url = cls._get_url("translations") + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) + url = cls._get_url("translations", deployment_id=deployment_id, api_type=api_type, api_version=api_version) response, _, api_key = await requestor.arequest( "post", url, files=files, params=data ) @@ -220,6 +494,24 @@ async def atranslate( response, api_key, api_version, organization ) + @overload + @classmethod + async def atranscribe_raw( + cls, + *, + deployment_id=None, + file=None, + filename=None, + api_key=None, + api_base=None, + api_type=None, + api_version=None, + organization=None, + **params, + ): + ... + + @overload @classmethod async def atranscribe_raw( cls, @@ -233,6 +525,31 @@ async def atranscribe_raw( organization=None, **params, ): + ... + + @classmethod + async def atranscribe_raw( + cls, + *args, + **params, + ): + if len(args) > 8: + raise TypeError( + f"atranscribe_raw() takes from 4 to 9 positional arguments but {len(args)+1} were given" + ) + check_required(*args, method_name="atranscribe_raw", required=["model", "file", "filename"], **params) + + positional = list(args) + model = positional.pop(0) if positional else params.pop("model", None) + file = positional.pop(0) if positional else params.pop("file", None) + filename = positional.pop(0) if positional else params.pop("filename", None) + api_key = positional.pop(0) if positional else params.pop("api_key", None) + api_base = positional.pop(0) if positional else params.pop("api_base", None) + api_type = positional.pop(0) if positional else params.pop("api_type", None) + api_version = positional.pop(0) if positional else params.pop("api_version", None) + organization = positional.pop(0) if positional else params.pop("organization", None) + deployment_id = params.pop("deployment_id", None) + requestor, files, data = cls._prepare_request( file=file, filename=filename, @@ -244,7 +561,8 @@ async def atranscribe_raw( organization=organization, **params, ) - url = cls._get_url("transcriptions") + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) + url = cls._get_url("transcriptions", deployment_id=deployment_id, api_type=api_type, api_version=api_version) response, _, api_key = await requestor.arequest( "post", url, files=files, params=data ) @@ -252,6 +570,24 @@ async def atranscribe_raw( response, api_key, api_version, organization ) + @overload + @classmethod + async def atranslate_raw( + cls, + *, + deployment_id=None, + file=None, + filename=None, + api_key=None, + api_base=None, + api_type=None, + api_version=None, + organization=None, + **params, + ): + ... + + @overload @classmethod async def atranslate_raw( cls, @@ -265,6 +601,31 @@ async def atranslate_raw( organization=None, **params, ): + ... + + @classmethod + async def atranslate_raw( + cls, + *args, + **params, + ): + if len(args) > 8: + raise TypeError( + f"atranslate_raw() takes from 4 to 9 positional arguments but {len(args)+1} were given" + ) + check_required(*args, method_name="atranslate_raw", required=["model", "file", "filename"], **params) + + positional = list(args) + model = positional.pop(0) if positional else params.pop("model", None) + file = positional.pop(0) if positional else params.pop("file", None) + filename = positional.pop(0) if positional else params.pop("filename", None) + api_key = positional.pop(0) if positional else params.pop("api_key", None) + api_base = positional.pop(0) if positional else params.pop("api_base", None) + api_type = positional.pop(0) if positional else params.pop("api_type", None) + api_version = positional.pop(0) if positional else params.pop("api_version", None) + organization = positional.pop(0) if positional else params.pop("organization", None) + deployment_id = params.pop("deployment_id", None) + requestor, files, data = cls._prepare_request( file=file, filename=filename, @@ -276,7 +637,8 @@ async def atranslate_raw( organization=organization, **params, ) - url = cls._get_url("translations") + api_type, api_version = cls._get_api_type_and_version(api_type, api_version) + url = cls._get_url("translations", deployment_id=deployment_id, api_type=api_type, api_version=api_version) response, _, api_key = await requestor.arequest( "post", url, files=files, params=data ) diff --git a/openai/tests/test_audio_overloads.py b/openai/tests/test_audio_overloads.py new file mode 100644 index 0000000000..9cce4fdc9e --- /dev/null +++ b/openai/tests/test_audio_overloads.py @@ -0,0 +1,175 @@ +import openai +import pytest + + +API_BASE = "" +AZURE_API_KEY = "" +OPENAI_API_KEY = "" +API_VERSION = "" +AUDIO_FILE_PATH = "" + + +def test_transcribe(): + + # Invalid + with pytest.raises(TypeError) as e: + openai.Audio.transcribe( + "whisper-1", + open(AUDIO_FILE_PATH, "rb"), + "api_key", + "api_base", + "api_type", + "api_version", + "organization", + "extra", + ) + assert str(e.value) == "transcribe() takes from 3 to 8 positional arguments but 9 were given" + + with pytest.raises(TypeError) as e: + openai.Audio.transcribe() + assert str(e.value) == "transcribe() missing 2 required positional argument(s): model, file" + + with pytest.raises(TypeError) as e: + openai.Audio.transcribe( + "whisper-1" + ) + assert str(e.value) == "transcribe() missing 1 required positional argument(s): file" + + with pytest.raises(TypeError) as e: + openai.Audio.transcribe( + model="whisper-1" + ) + assert str(e.value) == "transcribe() missing 1 required positional argument(s): file" + + with pytest.raises(TypeError) as e: + openai.Audio.transcribe( + file=open(AUDIO_FILE_PATH, "rb") + ) + assert str(e.value) == "transcribe() missing 1 required positional argument(s): model" + + # Valid + openai.api_key = OPENAI_API_KEY + audio = openai.Audio.transcribe( + "whisper-1", + open(AUDIO_FILE_PATH, "rb") + ) + assert audio + + audio = openai.Audio.transcribe( + model="whisper-1", + file=open(AUDIO_FILE_PATH, "rb") + ) + assert audio + + openai.api_base = API_BASE + openai.api_key = AZURE_API_KEY + openai.api_type = "azure" + openai.api_version = API_VERSION + audio = openai.Audio.transcribe( + deployment_id="whisper-1", + file=open(AUDIO_FILE_PATH, "rb") + ) + assert audio + + +def test_transcribe_raw(): + + # Invalid + with pytest.raises(TypeError) as e: + openai.Audio.transcribe_raw( + "whisper-1", + open(AUDIO_FILE_PATH, "rb").read(), + "filename", + "api_key", + "api_base", + "api_type", + "api_version", + "organization", + "extra", + ) + assert str(e.value) == "transcribe_raw() takes from 4 to 9 positional arguments but 10 were given" + + with pytest.raises(TypeError) as e: + openai.Audio.transcribe_raw() + assert str(e.value) == "transcribe_raw() missing 3 required positional argument(s): model, file, filename" + + with pytest.raises(TypeError) as e: + openai.Audio.transcribe_raw( + "whisper-1" + ) + assert str(e.value) == "transcribe_raw() missing 2 required positional argument(s): file, filename" + + with pytest.raises(TypeError) as e: + openai.Audio.transcribe_raw( + "whisper-1", + open(AUDIO_FILE_PATH, "rb").read() + ) + assert str(e.value) == "transcribe_raw() missing 1 required positional argument(s): filename" + + with pytest.raises(TypeError) as e: + openai.Audio.transcribe_raw( + model="whisper-1" + ) + assert str(e.value) == "transcribe_raw() missing 2 required positional argument(s): file, filename" + + with pytest.raises(TypeError) as e: + openai.Audio.transcribe_raw( + file=open(AUDIO_FILE_PATH, "rb").read() + ) + assert str(e.value) == "transcribe_raw() missing 2 required positional argument(s): model, filename" + + with pytest.raises(TypeError) as e: + openai.Audio.transcribe_raw( + filename="recording.m4a" + ) + assert str(e.value) == "transcribe_raw() missing 2 required positional argument(s): model, file" + + with pytest.raises(TypeError) as e: + openai.Audio.transcribe_raw( + model="whisper-1", + file=open(AUDIO_FILE_PATH, "rb").read() + ) + assert str(e.value) == "transcribe_raw() missing 1 required positional argument(s): filename" + + with pytest.raises(TypeError) as e: + openai.Audio.transcribe_raw( + model="whisper-1", + filename="recording.m4a" + ) + assert str(e.value) == "transcribe_raw() missing 1 required positional argument(s): file" + + + with pytest.raises(TypeError) as e: + openai.Audio.transcribe_raw( + file=open(AUDIO_FILE_PATH, "rb").read(), + filename="recording.m4a" + ) + assert str(e.value) == "transcribe_raw() missing 1 required positional argument(s): model" + + + # Valid + openai.api_key = OPENAI_API_KEY + audio = openai.Audio.transcribe_raw( + "whisper-1", + open(AUDIO_FILE_PATH, "rb").read(), + "recording.m4a" + ) + assert audio + + audio = openai.Audio.transcribe_raw( + model="whisper-1", + file=open(AUDIO_FILE_PATH, "rb").read(), + filename="recording.m4a" + ) + assert audio + + openai.api_base = API_BASE + openai.api_key = AZURE_API_KEY + openai.api_type = "azure" + openai.api_version = API_VERSION + audio = openai.Audio.transcribe_raw( + deployment_id="whisper-1", + file=open(AUDIO_FILE_PATH, "rb").read(), + filename="recording.m4a" + ) + assert audio From 310e41551bf56b9ff8c7497871ce420d5c0dbf2b Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Wed, 30 Aug 2023 14:31:05 -0700 Subject: [PATCH 2/5] reorder overloads --- openai/api_resources/audio.py | 101 +++++++++++++++++----------------- 1 file changed, 51 insertions(+), 50 deletions(-) diff --git a/openai/api_resources/audio.py b/openai/api_resources/audio.py index 2fe321d6d9..68de0e9080 100644 --- a/openai/api_resources/audio.py +++ b/openai/api_resources/audio.py @@ -14,8 +14,7 @@ def check_required(*args, method_name, required, **kwargs): elif args_count > 0: args_count -= 1 continue - else: - missing.append(param) + missing.append(param) if missing and "deployment_id" not in kwargs: raise TypeError(f"{method_name}() missing {len(missing)} required positional argument(s): {', '.join(missing)}") @@ -62,9 +61,8 @@ def _prepare_request( @classmethod def transcribe( cls, - *, - deployment_id=None, - file=None, + model, + file, api_key=None, api_base=None, api_type=None, @@ -78,8 +76,9 @@ def transcribe( @classmethod def transcribe( cls, - model, - file, + *, + deployment_id=None, + file=None, api_key=None, api_base=None, api_type=None, @@ -133,9 +132,8 @@ def transcribe( @classmethod def translate( cls, - *, - deployment_id=None, - file=None, + model, + file, api_key=None, api_base=None, api_type=None, @@ -145,12 +143,14 @@ def translate( ): ... + @overload @classmethod def translate( cls, - model, - file, + *, + deployment_id=None, + file=None, api_key=None, api_base=None, api_type=None, @@ -204,10 +204,9 @@ def translate( @classmethod def transcribe_raw( cls, - *, - deployment_id=None, - file=None, - filename=None, + model, + file, + filename, api_key=None, api_base=None, api_type=None, @@ -217,13 +216,15 @@ def transcribe_raw( ): ... + @overload @classmethod def transcribe_raw( cls, - model, - file, - filename, + *, + deployment_id=None, + file=None, + filename=None, api_key=None, api_base=None, api_type=None, @@ -278,10 +279,9 @@ def transcribe_raw( @classmethod def translate_raw( cls, - *, - deployment_id=None, - file=None, - filename=None, + model, + file, + filename, api_key=None, api_base=None, api_type=None, @@ -295,9 +295,10 @@ def translate_raw( @classmethod def translate_raw( cls, - model, - file, - filename, + *, + deployment_id=None, + file=None, + filename=None, api_key=None, api_base=None, api_type=None, @@ -352,9 +353,8 @@ def translate_raw( @classmethod async def atranscribe( cls, - *, - deployment_id=None, - file=None, + model, + file, api_key=None, api_base=None, api_type=None, @@ -368,8 +368,9 @@ async def atranscribe( @classmethod async def atranscribe( cls, - model, - file, + *, + deployment_id=None, + file=None, api_key=None, api_base=None, api_type=None, @@ -425,9 +426,8 @@ async def atranscribe( @classmethod async def atranslate( cls, - *, - deployment_id=None, - file=None, + model, + file, api_key=None, api_base=None, api_type=None, @@ -441,8 +441,9 @@ async def atranslate( @classmethod async def atranslate( cls, - model, - file, + *, + deployment_id=None, + file=None, api_key=None, api_base=None, api_type=None, @@ -498,10 +499,9 @@ async def atranslate( @classmethod async def atranscribe_raw( cls, - *, - deployment_id=None, - file=None, - filename=None, + model, + file, + filename, api_key=None, api_base=None, api_type=None, @@ -515,9 +515,10 @@ async def atranscribe_raw( @classmethod async def atranscribe_raw( cls, - model, - file, - filename, + *, + deployment_id=None, + file=None, + filename=None, api_key=None, api_base=None, api_type=None, @@ -574,10 +575,9 @@ async def atranscribe_raw( @classmethod async def atranslate_raw( cls, - *, - deployment_id=None, - file=None, - filename=None, + model, + file, + filename, api_key=None, api_base=None, api_type=None, @@ -591,9 +591,10 @@ async def atranslate_raw( @classmethod async def atranslate_raw( cls, - model, - file, - filename, + *, + deployment_id=None, + file=None, + filename=None, api_key=None, api_base=None, api_type=None, From d66a949809232c47b329b04b9406c135b9d274ed Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Mon, 11 Sep 2023 16:12:09 -0700 Subject: [PATCH 3/5] add additional tests --- openai/api_resources/audio.py | 2 +- openai/tests/test_audio_overloads.py | 543 ++++++++++++++++++++++++++- 2 files changed, 538 insertions(+), 7 deletions(-) diff --git a/openai/api_resources/audio.py b/openai/api_resources/audio.py index 68de0e9080..d15ce7fd05 100644 --- a/openai/api_resources/audio.py +++ b/openai/api_resources/audio.py @@ -11,7 +11,7 @@ def check_required(*args, method_name, required, **kwargs): for param in required: if param in kwargs: continue - elif args_count > 0: + if args_count > 0: args_count -= 1 continue missing.append(param) diff --git a/openai/tests/test_audio_overloads.py b/openai/tests/test_audio_overloads.py index 9cce4fdc9e..0433d13a45 100644 --- a/openai/tests/test_audio_overloads.py +++ b/openai/tests/test_audio_overloads.py @@ -7,8 +7,10 @@ OPENAI_API_KEY = "" API_VERSION = "" AUDIO_FILE_PATH = "" +AUDIO_FILE_NAME = "" +# TRANSCRIBE ----------------------------------------------------------------------------------- def test_transcribe(): # Invalid @@ -48,7 +50,10 @@ def test_transcribe(): assert str(e.value) == "transcribe() missing 1 required positional argument(s): model" # Valid + openai.api_base = "https://api.openai.com/v1" + openai.api_type = "openai" openai.api_key = OPENAI_API_KEY + openai.api_version = None audio = openai.Audio.transcribe( "whisper-1", open(AUDIO_FILE_PATH, "rb") @@ -67,8 +72,76 @@ def test_transcribe(): openai.api_version = API_VERSION audio = openai.Audio.transcribe( deployment_id="whisper-1", + file=open(AUDIO_FILE_PATH, "rb"), + response_format="verbose_json" + ) + assert audio + + +@pytest.mark.asyncio +async def test_atranscribe(): + + # Invalid + with pytest.raises(TypeError) as e: + await openai.Audio.atranscribe( + "whisper-1", + open(AUDIO_FILE_PATH, "rb"), + "api_key", + "api_base", + "api_type", + "api_version", + "organization", + "extra", + ) + assert str(e.value) == "atranscribe() takes from 3 to 8 positional arguments but 9 were given" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranscribe() + assert str(e.value) == "atranscribe() missing 2 required positional argument(s): model, file" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranscribe( + "whisper-1" + ) + assert str(e.value) == "atranscribe() missing 1 required positional argument(s): file" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranscribe( + model="whisper-1" + ) + assert str(e.value) == "atranscribe() missing 1 required positional argument(s): file" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranscribe( + file=open(AUDIO_FILE_PATH, "rb") + ) + assert str(e.value) == "atranscribe() missing 1 required positional argument(s): model" + + # # Valid + openai.api_base = "https://api.openai.com/v1" + openai.api_type = "openai" + openai.api_key = OPENAI_API_KEY + openai.api_version = None + audio = await openai.Audio.atranscribe( + "whisper-1", + open(AUDIO_FILE_PATH, "rb") + ) + assert audio + + audio1 = await openai.Audio.atranscribe( + model="whisper-1", file=open(AUDIO_FILE_PATH, "rb") ) + assert audio1 + + openai.api_base = API_BASE + openai.api_key = AZURE_API_KEY + openai.api_type = "azure" + openai.api_version = API_VERSION + audio = await openai.Audio.atranscribe( + deployment_id="whisper-1", + file=open(AUDIO_FILE_PATH, "rb"), + ) assert audio @@ -120,7 +193,7 @@ def test_transcribe_raw(): with pytest.raises(TypeError) as e: openai.Audio.transcribe_raw( - filename="recording.m4a" + filename=AUDIO_FILE_NAME ) assert str(e.value) == "transcribe_raw() missing 2 required positional argument(s): model, file" @@ -134,7 +207,7 @@ def test_transcribe_raw(): with pytest.raises(TypeError) as e: openai.Audio.transcribe_raw( model="whisper-1", - filename="recording.m4a" + filename=AUDIO_FILE_NAME ) assert str(e.value) == "transcribe_raw() missing 1 required positional argument(s): file" @@ -142,24 +215,27 @@ def test_transcribe_raw(): with pytest.raises(TypeError) as e: openai.Audio.transcribe_raw( file=open(AUDIO_FILE_PATH, "rb").read(), - filename="recording.m4a" + filename=AUDIO_FILE_NAME ) assert str(e.value) == "transcribe_raw() missing 1 required positional argument(s): model" # Valid + openai.api_base = "https://api.openai.com/v1" + openai.api_type = "openai" openai.api_key = OPENAI_API_KEY + openai.api_version = None audio = openai.Audio.transcribe_raw( "whisper-1", open(AUDIO_FILE_PATH, "rb").read(), - "recording.m4a" + AUDIO_FILE_NAME ) assert audio audio = openai.Audio.transcribe_raw( model="whisper-1", file=open(AUDIO_FILE_PATH, "rb").read(), - filename="recording.m4a" + filename=AUDIO_FILE_NAME ) assert audio @@ -170,6 +246,461 @@ def test_transcribe_raw(): audio = openai.Audio.transcribe_raw( deployment_id="whisper-1", file=open(AUDIO_FILE_PATH, "rb").read(), - filename="recording.m4a" + filename=AUDIO_FILE_NAME + ) + assert audio + + +@pytest.mark.asyncio +async def test_atranscribe_raw(): + + # Invalid + with pytest.raises(TypeError) as e: + await openai.Audio.atranscribe_raw( + "whisper-1", + open(AUDIO_FILE_PATH, "rb").read(), + "filename", + "api_key", + "api_base", + "api_type", + "api_version", + "organization", + "extra", + ) + assert str(e.value) == "atranscribe_raw() takes from 4 to 9 positional arguments but 10 were given" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranscribe_raw() + assert str(e.value) == "atranscribe_raw() missing 3 required positional argument(s): model, file, filename" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranscribe_raw( + "whisper-1" + ) + assert str(e.value) == "atranscribe_raw() missing 2 required positional argument(s): file, filename" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranscribe_raw( + "whisper-1", + open(AUDIO_FILE_PATH, "rb").read() + ) + assert str(e.value) == "atranscribe_raw() missing 1 required positional argument(s): filename" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranscribe_raw( + model="whisper-1" + ) + assert str(e.value) == "atranscribe_raw() missing 2 required positional argument(s): file, filename" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranscribe_raw( + file=open(AUDIO_FILE_PATH, "rb").read() + ) + assert str(e.value) == "atranscribe_raw() missing 2 required positional argument(s): model, filename" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranscribe_raw( + filename=AUDIO_FILE_NAME + ) + assert str(e.value) == "atranscribe_raw() missing 2 required positional argument(s): model, file" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranscribe_raw( + model="whisper-1", + file=open(AUDIO_FILE_PATH, "rb").read() + ) + assert str(e.value) == "atranscribe_raw() missing 1 required positional argument(s): filename" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranscribe_raw( + model="whisper-1", + filename=AUDIO_FILE_NAME + ) + assert str(e.value) == "atranscribe_raw() missing 1 required positional argument(s): file" + + + with pytest.raises(TypeError) as e: + await openai.Audio.atranscribe_raw( + file=open(AUDIO_FILE_PATH, "rb").read(), + filename=AUDIO_FILE_NAME + ) + assert str(e.value) == "atranscribe_raw() missing 1 required positional argument(s): model" + + + # Valid + openai.api_base = "https://api.openai.com/v1" + openai.api_type = "openai" + openai.api_key = OPENAI_API_KEY + openai.api_version = None + audio = await openai.Audio.atranscribe_raw( + "whisper-1", + open(AUDIO_FILE_PATH, "rb").read(), + AUDIO_FILE_NAME + ) + assert audio + + audio = await openai.Audio.atranscribe_raw( + model="whisper-1", + file=open(AUDIO_FILE_PATH, "rb").read(), + filename=AUDIO_FILE_NAME + ) + assert audio + + openai.api_base = API_BASE + openai.api_key = AZURE_API_KEY + openai.api_type = "azure" + openai.api_version = API_VERSION + audio = await openai.Audio.atranscribe_raw( + deployment_id="whisper-1", + file=open(AUDIO_FILE_PATH, "rb").read(), + filename=AUDIO_FILE_NAME + ) + assert audio + + +# TRANSLATE ----------------------------------------------------------------------------------- + +def test_translate(): + + # Invalid + with pytest.raises(TypeError) as e: + openai.Audio.translate( + "whisper-1", + open(AUDIO_FILE_PATH, "rb"), + "api_key", + "api_base", + "api_type", + "api_version", + "organization", + "extra", + ) + assert str(e.value) == "translate() takes from 3 to 8 positional arguments but 9 were given" + + with pytest.raises(TypeError) as e: + openai.Audio.translate() + assert str(e.value) == "translate() missing 2 required positional argument(s): model, file" + + with pytest.raises(TypeError) as e: + openai.Audio.translate( + "whisper-1" + ) + assert str(e.value) == "translate() missing 1 required positional argument(s): file" + + with pytest.raises(TypeError) as e: + openai.Audio.translate( + model="whisper-1" + ) + assert str(e.value) == "translate() missing 1 required positional argument(s): file" + + with pytest.raises(TypeError) as e: + openai.Audio.translate( + file=open(AUDIO_FILE_PATH, "rb") + ) + assert str(e.value) == "translate() missing 1 required positional argument(s): model" + + # # Valid + openai.api_base = "https://api.openai.com/v1" + openai.api_type = "openai" + openai.api_key = OPENAI_API_KEY + openai.api_version = None + audio = openai.Audio.translate( + "whisper-1", + open(AUDIO_FILE_PATH, "rb") + ) + assert audio + + audio1 = openai.Audio.translate( + model="whisper-1", + file=open(AUDIO_FILE_PATH, "rb") + ) + assert audio1 + + openai.api_base = API_BASE + openai.api_key = AZURE_API_KEY + openai.api_type = "azure" + openai.api_version = API_VERSION + audio = openai.Audio.translate( + deployment_id="whisper-1", + file=open(AUDIO_FILE_PATH, "rb"), + ) + assert audio + + +@pytest.mark.asyncio +async def test_atranslate(): + + # Invalid + with pytest.raises(TypeError) as e: + await openai.Audio.atranslate( + "whisper-1", + open(AUDIO_FILE_PATH, "rb"), + "api_key", + "api_base", + "api_type", + "api_version", + "organization", + "extra", + ) + assert str(e.value) == "atranslate() takes from 3 to 8 positional arguments but 9 were given" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranslate() + assert str(e.value) == "atranslate() missing 2 required positional argument(s): model, file" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranslate( + "whisper-1" + ) + assert str(e.value) == "atranslate() missing 1 required positional argument(s): file" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranslate( + model="whisper-1" + ) + assert str(e.value) == "atranslate() missing 1 required positional argument(s): file" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranslate( + file=open(AUDIO_FILE_PATH, "rb") + ) + assert str(e.value) == "atranslate() missing 1 required positional argument(s): model" + + # # Valid + openai.api_base = "https://api.openai.com/v1" + openai.api_type = "openai" + openai.api_key = OPENAI_API_KEY + openai.api_version = None + audio = await openai.Audio.atranslate( + "whisper-1", + open(AUDIO_FILE_PATH, "rb") + ) + assert audio + + audio1 = await openai.Audio.atranslate( + model="whisper-1", + file=open(AUDIO_FILE_PATH, "rb") + ) + assert audio1 + + openai.api_base = API_BASE + openai.api_key = AZURE_API_KEY + openai.api_type = "azure" + openai.api_version = API_VERSION + audio = await openai.Audio.atranslate( + deployment_id="whisper-1", + file=open(AUDIO_FILE_PATH, "rb"), + ) + assert audio + + +def test_translate_raw(): + + # Invalid + with pytest.raises(TypeError) as e: + openai.Audio.translate_raw( + "whisper-1", + open(AUDIO_FILE_PATH, "rb").read(), + "filename", + "api_key", + "api_base", + "api_type", + "api_version", + "organization", + "extra", + ) + assert str(e.value) == "translate_raw() takes from 4 to 9 positional arguments but 10 were given" + + with pytest.raises(TypeError) as e: + openai.Audio.translate_raw() + assert str(e.value) == "translate_raw() missing 3 required positional argument(s): model, file, filename" + + with pytest.raises(TypeError) as e: + openai.Audio.translate_raw( + "whisper-1" + ) + assert str(e.value) == "translate_raw() missing 2 required positional argument(s): file, filename" + + with pytest.raises(TypeError) as e: + openai.Audio.translate_raw( + "whisper-1", + open(AUDIO_FILE_PATH, "rb").read() + ) + assert str(e.value) == "translate_raw() missing 1 required positional argument(s): filename" + + with pytest.raises(TypeError) as e: + openai.Audio.translate_raw( + model="whisper-1" + ) + assert str(e.value) == "translate_raw() missing 2 required positional argument(s): file, filename" + + with pytest.raises(TypeError) as e: + openai.Audio.translate_raw( + file=open(AUDIO_FILE_PATH, "rb").read() + ) + assert str(e.value) == "translate_raw() missing 2 required positional argument(s): model, filename" + + with pytest.raises(TypeError) as e: + openai.Audio.translate_raw( + filename=AUDIO_FILE_NAME + ) + assert str(e.value) == "translate_raw() missing 2 required positional argument(s): model, file" + + with pytest.raises(TypeError) as e: + openai.Audio.translate_raw( + model="whisper-1", + file=open(AUDIO_FILE_PATH, "rb").read() + ) + assert str(e.value) == "translate_raw() missing 1 required positional argument(s): filename" + + with pytest.raises(TypeError) as e: + openai.Audio.translate_raw( + model="whisper-1", + filename=AUDIO_FILE_NAME + ) + assert str(e.value) == "translate_raw() missing 1 required positional argument(s): file" + + + with pytest.raises(TypeError) as e: + openai.Audio.translate_raw( + file=open(AUDIO_FILE_PATH, "rb").read(), + filename=AUDIO_FILE_NAME + ) + assert str(e.value) == "translate_raw() missing 1 required positional argument(s): model" + + + # Valid + openai.api_base = "https://api.openai.com/v1" + openai.api_type = "openai" + openai.api_key = OPENAI_API_KEY + openai.api_version = None + audio = openai.Audio.translate_raw( + "whisper-1", + open(AUDIO_FILE_PATH, "rb").read(), + AUDIO_FILE_NAME + ) + assert audio + + audio = openai.Audio.translate_raw( + model="whisper-1", + file=open(AUDIO_FILE_PATH, "rb").read(), + filename=AUDIO_FILE_NAME + ) + assert audio + + openai.api_base = API_BASE + openai.api_key = AZURE_API_KEY + openai.api_type = "azure" + openai.api_version = API_VERSION + audio = openai.Audio.translate_raw( + deployment_id="whisper-1", + file=open(AUDIO_FILE_PATH, "rb").read(), + filename=AUDIO_FILE_NAME + ) + assert audio + + +@pytest.mark.asyncio +async def test_atranslate_raw(): + + # Invalid + with pytest.raises(TypeError) as e: + await openai.Audio.atranslate_raw( + "whisper-1", + open(AUDIO_FILE_PATH, "rb").read(), + "filename", + "api_key", + "api_base", + "api_type", + "api_version", + "organization", + "extra", + ) + assert str(e.value) == "atranslate_raw() takes from 4 to 9 positional arguments but 10 were given" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranslate_raw() + assert str(e.value) == "atranslate_raw() missing 3 required positional argument(s): model, file, filename" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranslate_raw( + "whisper-1" + ) + assert str(e.value) == "atranslate_raw() missing 2 required positional argument(s): file, filename" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranslate_raw( + "whisper-1", + open(AUDIO_FILE_PATH, "rb").read() + ) + assert str(e.value) == "atranslate_raw() missing 1 required positional argument(s): filename" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranslate_raw( + model="whisper-1" + ) + assert str(e.value) == "atranslate_raw() missing 2 required positional argument(s): file, filename" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranslate_raw( + file=open(AUDIO_FILE_PATH, "rb").read() + ) + assert str(e.value) == "atranslate_raw() missing 2 required positional argument(s): model, filename" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranslate_raw( + filename=AUDIO_FILE_NAME + ) + assert str(e.value) == "atranslate_raw() missing 2 required positional argument(s): model, file" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranslate_raw( + model="whisper-1", + file=open(AUDIO_FILE_PATH, "rb").read() + ) + assert str(e.value) == "atranslate_raw() missing 1 required positional argument(s): filename" + + with pytest.raises(TypeError) as e: + await openai.Audio.atranslate_raw( + model="whisper-1", + filename=AUDIO_FILE_NAME + ) + assert str(e.value) == "atranslate_raw() missing 1 required positional argument(s): file" + + + with pytest.raises(TypeError) as e: + await openai.Audio.atranslate_raw( + file=open(AUDIO_FILE_PATH, "rb").read(), + filename=AUDIO_FILE_NAME + ) + assert str(e.value) == "atranslate_raw() missing 1 required positional argument(s): model" + + + # Valid + openai.api_base = "https://api.openai.com/v1" + openai.api_type = "openai" + openai.api_key = OPENAI_API_KEY + openai.api_version = None + audio = await openai.Audio.atranslate_raw( + "whisper-1", + open(AUDIO_FILE_PATH, "rb").read(), + AUDIO_FILE_NAME + ) + assert audio + + audio = await openai.Audio.atranslate_raw( + model="whisper-1", + file=open(AUDIO_FILE_PATH, "rb").read(), + filename=AUDIO_FILE_NAME + ) + assert audio + + openai.api_base = API_BASE + openai.api_key = AZURE_API_KEY + openai.api_type = "azure" + openai.api_version = API_VERSION + audio = await openai.Audio.atranslate_raw( + deployment_id="whisper-1", + file=open(AUDIO_FILE_PATH, "rb").read(), + filename=AUDIO_FILE_NAME ) assert audio From 021c95b04d5263059da2458587ab1ea8b67af314 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Tue, 12 Sep 2023 10:45:25 -0700 Subject: [PATCH 4/5] add helper function to utils --- openai/api_resources/audio.py | 31 +++++++--------------------- openai/tests/test_audio_overloads.py | 1 - openai/util.py | 18 ++++++++++++++++ 3 files changed, 26 insertions(+), 24 deletions(-) diff --git a/openai/api_resources/audio.py b/openai/api_resources/audio.py index d15ce7fd05..76cc1d36ee 100644 --- a/openai/api_resources/audio.py +++ b/openai/api_resources/audio.py @@ -5,21 +5,6 @@ from openai.api_resources.abstract import APIResource -def check_required(*args, method_name, required, **kwargs): - missing = [] - args_count = len(args) - for param in required: - if param in kwargs: - continue - if args_count > 0: - args_count -= 1 - continue - missing.append(param) - - if missing and "deployment_id" not in kwargs: - raise TypeError(f"{method_name}() missing {len(missing)} required positional argument(s): {', '.join(missing)}") - - class Audio(APIResource): OBJECT_NAME = "audio" @@ -98,7 +83,7 @@ def transcribe( raise TypeError( f"transcribe() takes from 3 to 8 positional arguments but {len(args)+1} were given" ) - check_required(*args, method_name="transcribe", required=["model", "file"], **params) + util.check_required(*args, method_name="transcribe", required=["model", "file"], **params) positional = list(args) model = positional.pop(0) if positional else params.pop("model", None) @@ -170,7 +155,7 @@ def translate( raise TypeError( f"translate() takes from 3 to 8 positional arguments but {len(args)+1} were given" ) - check_required(*args, method_name="translate", required=["model", "file"], **params) + util.check_required(*args, method_name="translate", required=["model", "file"], **params) positional = list(args) model = positional.pop(0) if positional else params.pop("model", None) @@ -244,7 +229,7 @@ def transcribe_raw( raise TypeError( f"transcribe_raw() takes from 4 to 9 positional arguments but {len(args)+1} were given" ) - check_required(*args, method_name="transcribe_raw", required=["model", "file", "filename"], **params) + util.check_required(*args, method_name="transcribe_raw", required=["model", "file", "filename"], **params) positional = list(args) model = positional.pop(0) if positional else params.pop("model", None) @@ -318,7 +303,7 @@ def translate_raw( raise TypeError( f"translate_raw() takes from 4 to 9 positional arguments but {len(args)+1} were given" ) - check_required(*args, method_name="translate_raw", required=["model", "file", "filename"], **params) + util.check_required(*args, method_name="translate_raw", required=["model", "file", "filename"], **params) positional = list(args) model = positional.pop(0) if positional else params.pop("model", None) @@ -390,7 +375,7 @@ async def atranscribe( raise TypeError( f"atranscribe() takes from 3 to 8 positional arguments but {len(args)+1} were given" ) - check_required(*args, method_name="atranscribe", required=["model", "file"], **params) + util.check_required(*args, method_name="atranscribe", required=["model", "file"], **params) positional = list(args) model = positional.pop(0) if positional else params.pop("model", None) @@ -463,7 +448,7 @@ async def atranslate( raise TypeError( f"atranslate() takes from 3 to 8 positional arguments but {len(args)+1} were given" ) - check_required(*args, method_name="atranslate", required=["model", "file"], **params) + util.check_required(*args, method_name="atranslate", required=["model", "file"], **params) positional = list(args) model = positional.pop(0) if positional else params.pop("model", None) @@ -538,7 +523,7 @@ async def atranscribe_raw( raise TypeError( f"atranscribe_raw() takes from 4 to 9 positional arguments but {len(args)+1} were given" ) - check_required(*args, method_name="atranscribe_raw", required=["model", "file", "filename"], **params) + util.check_required(*args, method_name="atranscribe_raw", required=["model", "file", "filename"], **params) positional = list(args) model = positional.pop(0) if positional else params.pop("model", None) @@ -614,7 +599,7 @@ async def atranslate_raw( raise TypeError( f"atranslate_raw() takes from 4 to 9 positional arguments but {len(args)+1} were given" ) - check_required(*args, method_name="atranslate_raw", required=["model", "file", "filename"], **params) + util.check_required(*args, method_name="atranslate_raw", required=["model", "file", "filename"], **params) positional = list(args) model = positional.pop(0) if positional else params.pop("model", None) diff --git a/openai/tests/test_audio_overloads.py b/openai/tests/test_audio_overloads.py index 0433d13a45..56bf2ebcf2 100644 --- a/openai/tests/test_audio_overloads.py +++ b/openai/tests/test_audio_overloads.py @@ -73,7 +73,6 @@ def test_transcribe(): audio = openai.Audio.transcribe( deployment_id="whisper-1", file=open(AUDIO_FILE_PATH, "rb"), - response_format="verbose_json" ) assert audio diff --git a/openai/util.py b/openai/util.py index 5501d5b67e..5d16092a1b 100644 --- a/openai/util.py +++ b/openai/util.py @@ -186,3 +186,21 @@ def default_api_key() -> str: raise openai.error.AuthenticationError( "No API key provided. You can set your API key in code using 'openai.api_key = ', or you can set the environment variable OPENAI_API_KEY=). If your API key is stored in a file, you can point the openai module at it with 'openai.api_key_path = '. You can generate API keys in the OpenAI web interface. See https://platform.openai.com/account/api-keys for details." ) + + +def check_required(*args, method_name, required, **kwargs): + """Checks that all required parameters have been provided + to the method where overloads are used to maintain existing behavior. + """ + missing = [] + args_count = len(args) + for param in required: + if param in kwargs: + continue + if args_count > 0: + args_count -= 1 + continue + missing.append(param) + + if missing and "deployment_id" not in kwargs: + raise TypeError(f"{method_name}() missing {len(missing)} required positional argument(s): {', '.join(missing)}") From 8ae05cf8b146c4023f29aa8e6018a766ac97c813 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Thu, 21 Sep 2023 12:19:06 -0700 Subject: [PATCH 5/5] simplify - azure users will just need to pass model and deployment_id --- openai/api_resources/audio.py | 324 +----------- openai/tests/test_audio_overloads.py | 705 --------------------------- openai/util.py | 18 - 3 files changed, 1 insertion(+), 1046 deletions(-) delete mode 100644 openai/tests/test_audio_overloads.py diff --git a/openai/api_resources/audio.py b/openai/api_resources/audio.py index 76cc1d36ee..cb316f07f1 100644 --- a/openai/api_resources/audio.py +++ b/openai/api_resources/audio.py @@ -1,4 +1,4 @@ -from typing import Any, List, overload +from typing import Any, List import openai from openai import api_requestor, util @@ -42,7 +42,6 @@ def _prepare_request( files.append(("file", (filename, file, "application/octet-stream"))) return requestor, files, data - @overload @classmethod def transcribe( cls, @@ -53,48 +52,10 @@ def transcribe( api_type=None, api_version=None, organization=None, - **params, - ): - ... - - @overload - @classmethod - def transcribe( - cls, *, deployment_id=None, - file=None, - api_key=None, - api_base=None, - api_type=None, - api_version=None, - organization=None, - **params, - ): - ... - - @classmethod - def transcribe( - cls, - *args, **params, ): - if len(args) > 7: - raise TypeError( - f"transcribe() takes from 3 to 8 positional arguments but {len(args)+1} were given" - ) - util.check_required(*args, method_name="transcribe", required=["model", "file"], **params) - - positional = list(args) - model = positional.pop(0) if positional else params.pop("model", None) - file = positional.pop(0) if positional else params.pop("file", None) - api_key = positional.pop(0) if positional else params.pop("api_key", None) - api_base = positional.pop(0) if positional else params.pop("api_base", None) - api_type = positional.pop(0) if positional else params.pop("api_type", None) - api_version = positional.pop(0) if positional else params.pop("api_version", None) - organization = positional.pop(0) if positional else params.pop("organization", None) - deployment_id = params.pop("deployment_id", None) - requestor, files, data = cls._prepare_request( file=file, filename=file.name, @@ -113,7 +74,6 @@ def transcribe( response, api_key, api_version, organization ) - @overload @classmethod def translate( cls, @@ -124,49 +84,10 @@ def translate( api_type=None, api_version=None, organization=None, - **params, - ): - ... - - - @overload - @classmethod - def translate( - cls, *, deployment_id=None, - file=None, - api_key=None, - api_base=None, - api_type=None, - api_version=None, - organization=None, **params, ): - ... - - @classmethod - def translate( - cls, - *args, - **params, - ): - if len(args) > 7: - raise TypeError( - f"translate() takes from 3 to 8 positional arguments but {len(args)+1} were given" - ) - util.check_required(*args, method_name="translate", required=["model", "file"], **params) - - positional = list(args) - model = positional.pop(0) if positional else params.pop("model", None) - file = positional.pop(0) if positional else params.pop("file", None) - api_key = positional.pop(0) if positional else params.pop("api_key", None) - api_base = positional.pop(0) if positional else params.pop("api_base", None) - api_type = positional.pop(0) if positional else params.pop("api_type", None) - api_version = positional.pop(0) if positional else params.pop("api_version", None) - organization = positional.pop(0) if positional else params.pop("organization", None) - deployment_id = params.pop("deployment_id", None) - requestor, files, data = cls._prepare_request( file=file, filename=file.name, @@ -185,7 +106,6 @@ def translate( response, api_key, api_version, organization ) - @overload @classmethod def transcribe_raw( cls, @@ -197,51 +117,10 @@ def transcribe_raw( api_type=None, api_version=None, organization=None, - **params, - ): - ... - - - @overload - @classmethod - def transcribe_raw( - cls, *, deployment_id=None, - file=None, - filename=None, - api_key=None, - api_base=None, - api_type=None, - api_version=None, - organization=None, **params, ): - ... - - @classmethod - def transcribe_raw( - cls, - *args, - **params, - ): - if len(args) > 8: - raise TypeError( - f"transcribe_raw() takes from 4 to 9 positional arguments but {len(args)+1} were given" - ) - util.check_required(*args, method_name="transcribe_raw", required=["model", "file", "filename"], **params) - - positional = list(args) - model = positional.pop(0) if positional else params.pop("model", None) - file = positional.pop(0) if positional else params.pop("file", None) - filename = positional.pop(0) if positional else params.pop("filename", None) - api_key = positional.pop(0) if positional else params.pop("api_key", None) - api_base = positional.pop(0) if positional else params.pop("api_base", None) - api_type = positional.pop(0) if positional else params.pop("api_type", None) - api_version = positional.pop(0) if positional else params.pop("api_version", None) - organization = positional.pop(0) if positional else params.pop("organization", None) - deployment_id = params.pop("deployment_id", None) - requestor, files, data = cls._prepare_request( file=file, filename=filename, @@ -260,7 +139,6 @@ def transcribe_raw( response, api_key, api_version, organization ) - @overload @classmethod def translate_raw( cls, @@ -272,50 +150,10 @@ def translate_raw( api_type=None, api_version=None, organization=None, - **params, - ): - ... - - @overload - @classmethod - def translate_raw( - cls, *, deployment_id=None, - file=None, - filename=None, - api_key=None, - api_base=None, - api_type=None, - api_version=None, - organization=None, **params, ): - ... - - @classmethod - def translate_raw( - cls, - *args, - **params, - ): - if len(args) > 8: - raise TypeError( - f"translate_raw() takes from 4 to 9 positional arguments but {len(args)+1} were given" - ) - util.check_required(*args, method_name="translate_raw", required=["model", "file", "filename"], **params) - - positional = list(args) - model = positional.pop(0) if positional else params.pop("model", None) - file = positional.pop(0) if positional else params.pop("file", None) - filename = positional.pop(0) if positional else params.pop("filename", None) - api_key = positional.pop(0) if positional else params.pop("api_key", None) - api_base = positional.pop(0) if positional else params.pop("api_base", None) - api_type = positional.pop(0) if positional else params.pop("api_type", None) - api_version = positional.pop(0) if positional else params.pop("api_version", None) - organization = positional.pop(0) if positional else params.pop("organization", None) - deployment_id = params.pop("deployment_id", None) - requestor, files, data = cls._prepare_request( file=file, filename=filename, @@ -334,7 +172,6 @@ def translate_raw( response, api_key, api_version, organization ) - @overload @classmethod async def atranscribe( cls, @@ -345,48 +182,10 @@ async def atranscribe( api_type=None, api_version=None, organization=None, - **params, - ): - ... - - @overload - @classmethod - async def atranscribe( - cls, *, deployment_id=None, - file=None, - api_key=None, - api_base=None, - api_type=None, - api_version=None, - organization=None, - **params, - ): - ... - - @classmethod - async def atranscribe( - cls, - *args, **params, ): - if len(args) > 7: - raise TypeError( - f"atranscribe() takes from 3 to 8 positional arguments but {len(args)+1} were given" - ) - util.check_required(*args, method_name="atranscribe", required=["model", "file"], **params) - - positional = list(args) - model = positional.pop(0) if positional else params.pop("model", None) - file = positional.pop(0) if positional else params.pop("file", None) - api_key = positional.pop(0) if positional else params.pop("api_key", None) - api_base = positional.pop(0) if positional else params.pop("api_base", None) - api_type = positional.pop(0) if positional else params.pop("api_type", None) - api_version = positional.pop(0) if positional else params.pop("api_version", None) - organization = positional.pop(0) if positional else params.pop("organization", None) - deployment_id = params.pop("deployment_id", None) - requestor, files, data = cls._prepare_request( file=file, filename=file.name, @@ -407,7 +206,6 @@ async def atranscribe( response, api_key, api_version, organization ) - @overload @classmethod async def atranslate( cls, @@ -418,48 +216,10 @@ async def atranslate( api_type=None, api_version=None, organization=None, - **params, - ): - ... - - @overload - @classmethod - async def atranslate( - cls, *, deployment_id=None, - file=None, - api_key=None, - api_base=None, - api_type=None, - api_version=None, - organization=None, **params, ): - ... - - @classmethod - async def atranslate( - cls, - *args, - **params, - ): - if len(args) > 7: - raise TypeError( - f"atranslate() takes from 3 to 8 positional arguments but {len(args)+1} were given" - ) - util.check_required(*args, method_name="atranslate", required=["model", "file"], **params) - - positional = list(args) - model = positional.pop(0) if positional else params.pop("model", None) - file = positional.pop(0) if positional else params.pop("file", None) - api_key = positional.pop(0) if positional else params.pop("api_key", None) - api_base = positional.pop(0) if positional else params.pop("api_base", None) - api_type = positional.pop(0) if positional else params.pop("api_type", None) - api_version = positional.pop(0) if positional else params.pop("api_version", None) - organization = positional.pop(0) if positional else params.pop("organization", None) - deployment_id = params.pop("deployment_id", None) - requestor, files, data = cls._prepare_request( file=file, filename=file.name, @@ -480,7 +240,6 @@ async def atranslate( response, api_key, api_version, organization ) - @overload @classmethod async def atranscribe_raw( cls, @@ -492,50 +251,10 @@ async def atranscribe_raw( api_type=None, api_version=None, organization=None, - **params, - ): - ... - - @overload - @classmethod - async def atranscribe_raw( - cls, *, deployment_id=None, - file=None, - filename=None, - api_key=None, - api_base=None, - api_type=None, - api_version=None, - organization=None, **params, ): - ... - - @classmethod - async def atranscribe_raw( - cls, - *args, - **params, - ): - if len(args) > 8: - raise TypeError( - f"atranscribe_raw() takes from 4 to 9 positional arguments but {len(args)+1} were given" - ) - util.check_required(*args, method_name="atranscribe_raw", required=["model", "file", "filename"], **params) - - positional = list(args) - model = positional.pop(0) if positional else params.pop("model", None) - file = positional.pop(0) if positional else params.pop("file", None) - filename = positional.pop(0) if positional else params.pop("filename", None) - api_key = positional.pop(0) if positional else params.pop("api_key", None) - api_base = positional.pop(0) if positional else params.pop("api_base", None) - api_type = positional.pop(0) if positional else params.pop("api_type", None) - api_version = positional.pop(0) if positional else params.pop("api_version", None) - organization = positional.pop(0) if positional else params.pop("organization", None) - deployment_id = params.pop("deployment_id", None) - requestor, files, data = cls._prepare_request( file=file, filename=filename, @@ -556,7 +275,6 @@ async def atranscribe_raw( response, api_key, api_version, organization ) - @overload @classmethod async def atranslate_raw( cls, @@ -568,50 +286,10 @@ async def atranslate_raw( api_type=None, api_version=None, organization=None, - **params, - ): - ... - - @overload - @classmethod - async def atranslate_raw( - cls, *, deployment_id=None, - file=None, - filename=None, - api_key=None, - api_base=None, - api_type=None, - api_version=None, - organization=None, **params, ): - ... - - @classmethod - async def atranslate_raw( - cls, - *args, - **params, - ): - if len(args) > 8: - raise TypeError( - f"atranslate_raw() takes from 4 to 9 positional arguments but {len(args)+1} were given" - ) - util.check_required(*args, method_name="atranslate_raw", required=["model", "file", "filename"], **params) - - positional = list(args) - model = positional.pop(0) if positional else params.pop("model", None) - file = positional.pop(0) if positional else params.pop("file", None) - filename = positional.pop(0) if positional else params.pop("filename", None) - api_key = positional.pop(0) if positional else params.pop("api_key", None) - api_base = positional.pop(0) if positional else params.pop("api_base", None) - api_type = positional.pop(0) if positional else params.pop("api_type", None) - api_version = positional.pop(0) if positional else params.pop("api_version", None) - organization = positional.pop(0) if positional else params.pop("organization", None) - deployment_id = params.pop("deployment_id", None) - requestor, files, data = cls._prepare_request( file=file, filename=filename, diff --git a/openai/tests/test_audio_overloads.py b/openai/tests/test_audio_overloads.py deleted file mode 100644 index 56bf2ebcf2..0000000000 --- a/openai/tests/test_audio_overloads.py +++ /dev/null @@ -1,705 +0,0 @@ -import openai -import pytest - - -API_BASE = "" -AZURE_API_KEY = "" -OPENAI_API_KEY = "" -API_VERSION = "" -AUDIO_FILE_PATH = "" -AUDIO_FILE_NAME = "" - - -# TRANSCRIBE ----------------------------------------------------------------------------------- -def test_transcribe(): - - # Invalid - with pytest.raises(TypeError) as e: - openai.Audio.transcribe( - "whisper-1", - open(AUDIO_FILE_PATH, "rb"), - "api_key", - "api_base", - "api_type", - "api_version", - "organization", - "extra", - ) - assert str(e.value) == "transcribe() takes from 3 to 8 positional arguments but 9 were given" - - with pytest.raises(TypeError) as e: - openai.Audio.transcribe() - assert str(e.value) == "transcribe() missing 2 required positional argument(s): model, file" - - with pytest.raises(TypeError) as e: - openai.Audio.transcribe( - "whisper-1" - ) - assert str(e.value) == "transcribe() missing 1 required positional argument(s): file" - - with pytest.raises(TypeError) as e: - openai.Audio.transcribe( - model="whisper-1" - ) - assert str(e.value) == "transcribe() missing 1 required positional argument(s): file" - - with pytest.raises(TypeError) as e: - openai.Audio.transcribe( - file=open(AUDIO_FILE_PATH, "rb") - ) - assert str(e.value) == "transcribe() missing 1 required positional argument(s): model" - - # Valid - openai.api_base = "https://api.openai.com/v1" - openai.api_type = "openai" - openai.api_key = OPENAI_API_KEY - openai.api_version = None - audio = openai.Audio.transcribe( - "whisper-1", - open(AUDIO_FILE_PATH, "rb") - ) - assert audio - - audio = openai.Audio.transcribe( - model="whisper-1", - file=open(AUDIO_FILE_PATH, "rb") - ) - assert audio - - openai.api_base = API_BASE - openai.api_key = AZURE_API_KEY - openai.api_type = "azure" - openai.api_version = API_VERSION - audio = openai.Audio.transcribe( - deployment_id="whisper-1", - file=open(AUDIO_FILE_PATH, "rb"), - ) - assert audio - - -@pytest.mark.asyncio -async def test_atranscribe(): - - # Invalid - with pytest.raises(TypeError) as e: - await openai.Audio.atranscribe( - "whisper-1", - open(AUDIO_FILE_PATH, "rb"), - "api_key", - "api_base", - "api_type", - "api_version", - "organization", - "extra", - ) - assert str(e.value) == "atranscribe() takes from 3 to 8 positional arguments but 9 were given" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranscribe() - assert str(e.value) == "atranscribe() missing 2 required positional argument(s): model, file" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranscribe( - "whisper-1" - ) - assert str(e.value) == "atranscribe() missing 1 required positional argument(s): file" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranscribe( - model="whisper-1" - ) - assert str(e.value) == "atranscribe() missing 1 required positional argument(s): file" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranscribe( - file=open(AUDIO_FILE_PATH, "rb") - ) - assert str(e.value) == "atranscribe() missing 1 required positional argument(s): model" - - # # Valid - openai.api_base = "https://api.openai.com/v1" - openai.api_type = "openai" - openai.api_key = OPENAI_API_KEY - openai.api_version = None - audio = await openai.Audio.atranscribe( - "whisper-1", - open(AUDIO_FILE_PATH, "rb") - ) - assert audio - - audio1 = await openai.Audio.atranscribe( - model="whisper-1", - file=open(AUDIO_FILE_PATH, "rb") - ) - assert audio1 - - openai.api_base = API_BASE - openai.api_key = AZURE_API_KEY - openai.api_type = "azure" - openai.api_version = API_VERSION - audio = await openai.Audio.atranscribe( - deployment_id="whisper-1", - file=open(AUDIO_FILE_PATH, "rb"), - ) - assert audio - - -def test_transcribe_raw(): - - # Invalid - with pytest.raises(TypeError) as e: - openai.Audio.transcribe_raw( - "whisper-1", - open(AUDIO_FILE_PATH, "rb").read(), - "filename", - "api_key", - "api_base", - "api_type", - "api_version", - "organization", - "extra", - ) - assert str(e.value) == "transcribe_raw() takes from 4 to 9 positional arguments but 10 were given" - - with pytest.raises(TypeError) as e: - openai.Audio.transcribe_raw() - assert str(e.value) == "transcribe_raw() missing 3 required positional argument(s): model, file, filename" - - with pytest.raises(TypeError) as e: - openai.Audio.transcribe_raw( - "whisper-1" - ) - assert str(e.value) == "transcribe_raw() missing 2 required positional argument(s): file, filename" - - with pytest.raises(TypeError) as e: - openai.Audio.transcribe_raw( - "whisper-1", - open(AUDIO_FILE_PATH, "rb").read() - ) - assert str(e.value) == "transcribe_raw() missing 1 required positional argument(s): filename" - - with pytest.raises(TypeError) as e: - openai.Audio.transcribe_raw( - model="whisper-1" - ) - assert str(e.value) == "transcribe_raw() missing 2 required positional argument(s): file, filename" - - with pytest.raises(TypeError) as e: - openai.Audio.transcribe_raw( - file=open(AUDIO_FILE_PATH, "rb").read() - ) - assert str(e.value) == "transcribe_raw() missing 2 required positional argument(s): model, filename" - - with pytest.raises(TypeError) as e: - openai.Audio.transcribe_raw( - filename=AUDIO_FILE_NAME - ) - assert str(e.value) == "transcribe_raw() missing 2 required positional argument(s): model, file" - - with pytest.raises(TypeError) as e: - openai.Audio.transcribe_raw( - model="whisper-1", - file=open(AUDIO_FILE_PATH, "rb").read() - ) - assert str(e.value) == "transcribe_raw() missing 1 required positional argument(s): filename" - - with pytest.raises(TypeError) as e: - openai.Audio.transcribe_raw( - model="whisper-1", - filename=AUDIO_FILE_NAME - ) - assert str(e.value) == "transcribe_raw() missing 1 required positional argument(s): file" - - - with pytest.raises(TypeError) as e: - openai.Audio.transcribe_raw( - file=open(AUDIO_FILE_PATH, "rb").read(), - filename=AUDIO_FILE_NAME - ) - assert str(e.value) == "transcribe_raw() missing 1 required positional argument(s): model" - - - # Valid - openai.api_base = "https://api.openai.com/v1" - openai.api_type = "openai" - openai.api_key = OPENAI_API_KEY - openai.api_version = None - audio = openai.Audio.transcribe_raw( - "whisper-1", - open(AUDIO_FILE_PATH, "rb").read(), - AUDIO_FILE_NAME - ) - assert audio - - audio = openai.Audio.transcribe_raw( - model="whisper-1", - file=open(AUDIO_FILE_PATH, "rb").read(), - filename=AUDIO_FILE_NAME - ) - assert audio - - openai.api_base = API_BASE - openai.api_key = AZURE_API_KEY - openai.api_type = "azure" - openai.api_version = API_VERSION - audio = openai.Audio.transcribe_raw( - deployment_id="whisper-1", - file=open(AUDIO_FILE_PATH, "rb").read(), - filename=AUDIO_FILE_NAME - ) - assert audio - - -@pytest.mark.asyncio -async def test_atranscribe_raw(): - - # Invalid - with pytest.raises(TypeError) as e: - await openai.Audio.atranscribe_raw( - "whisper-1", - open(AUDIO_FILE_PATH, "rb").read(), - "filename", - "api_key", - "api_base", - "api_type", - "api_version", - "organization", - "extra", - ) - assert str(e.value) == "atranscribe_raw() takes from 4 to 9 positional arguments but 10 were given" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranscribe_raw() - assert str(e.value) == "atranscribe_raw() missing 3 required positional argument(s): model, file, filename" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranscribe_raw( - "whisper-1" - ) - assert str(e.value) == "atranscribe_raw() missing 2 required positional argument(s): file, filename" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranscribe_raw( - "whisper-1", - open(AUDIO_FILE_PATH, "rb").read() - ) - assert str(e.value) == "atranscribe_raw() missing 1 required positional argument(s): filename" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranscribe_raw( - model="whisper-1" - ) - assert str(e.value) == "atranscribe_raw() missing 2 required positional argument(s): file, filename" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranscribe_raw( - file=open(AUDIO_FILE_PATH, "rb").read() - ) - assert str(e.value) == "atranscribe_raw() missing 2 required positional argument(s): model, filename" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranscribe_raw( - filename=AUDIO_FILE_NAME - ) - assert str(e.value) == "atranscribe_raw() missing 2 required positional argument(s): model, file" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranscribe_raw( - model="whisper-1", - file=open(AUDIO_FILE_PATH, "rb").read() - ) - assert str(e.value) == "atranscribe_raw() missing 1 required positional argument(s): filename" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranscribe_raw( - model="whisper-1", - filename=AUDIO_FILE_NAME - ) - assert str(e.value) == "atranscribe_raw() missing 1 required positional argument(s): file" - - - with pytest.raises(TypeError) as e: - await openai.Audio.atranscribe_raw( - file=open(AUDIO_FILE_PATH, "rb").read(), - filename=AUDIO_FILE_NAME - ) - assert str(e.value) == "atranscribe_raw() missing 1 required positional argument(s): model" - - - # Valid - openai.api_base = "https://api.openai.com/v1" - openai.api_type = "openai" - openai.api_key = OPENAI_API_KEY - openai.api_version = None - audio = await openai.Audio.atranscribe_raw( - "whisper-1", - open(AUDIO_FILE_PATH, "rb").read(), - AUDIO_FILE_NAME - ) - assert audio - - audio = await openai.Audio.atranscribe_raw( - model="whisper-1", - file=open(AUDIO_FILE_PATH, "rb").read(), - filename=AUDIO_FILE_NAME - ) - assert audio - - openai.api_base = API_BASE - openai.api_key = AZURE_API_KEY - openai.api_type = "azure" - openai.api_version = API_VERSION - audio = await openai.Audio.atranscribe_raw( - deployment_id="whisper-1", - file=open(AUDIO_FILE_PATH, "rb").read(), - filename=AUDIO_FILE_NAME - ) - assert audio - - -# TRANSLATE ----------------------------------------------------------------------------------- - -def test_translate(): - - # Invalid - with pytest.raises(TypeError) as e: - openai.Audio.translate( - "whisper-1", - open(AUDIO_FILE_PATH, "rb"), - "api_key", - "api_base", - "api_type", - "api_version", - "organization", - "extra", - ) - assert str(e.value) == "translate() takes from 3 to 8 positional arguments but 9 were given" - - with pytest.raises(TypeError) as e: - openai.Audio.translate() - assert str(e.value) == "translate() missing 2 required positional argument(s): model, file" - - with pytest.raises(TypeError) as e: - openai.Audio.translate( - "whisper-1" - ) - assert str(e.value) == "translate() missing 1 required positional argument(s): file" - - with pytest.raises(TypeError) as e: - openai.Audio.translate( - model="whisper-1" - ) - assert str(e.value) == "translate() missing 1 required positional argument(s): file" - - with pytest.raises(TypeError) as e: - openai.Audio.translate( - file=open(AUDIO_FILE_PATH, "rb") - ) - assert str(e.value) == "translate() missing 1 required positional argument(s): model" - - # # Valid - openai.api_base = "https://api.openai.com/v1" - openai.api_type = "openai" - openai.api_key = OPENAI_API_KEY - openai.api_version = None - audio = openai.Audio.translate( - "whisper-1", - open(AUDIO_FILE_PATH, "rb") - ) - assert audio - - audio1 = openai.Audio.translate( - model="whisper-1", - file=open(AUDIO_FILE_PATH, "rb") - ) - assert audio1 - - openai.api_base = API_BASE - openai.api_key = AZURE_API_KEY - openai.api_type = "azure" - openai.api_version = API_VERSION - audio = openai.Audio.translate( - deployment_id="whisper-1", - file=open(AUDIO_FILE_PATH, "rb"), - ) - assert audio - - -@pytest.mark.asyncio -async def test_atranslate(): - - # Invalid - with pytest.raises(TypeError) as e: - await openai.Audio.atranslate( - "whisper-1", - open(AUDIO_FILE_PATH, "rb"), - "api_key", - "api_base", - "api_type", - "api_version", - "organization", - "extra", - ) - assert str(e.value) == "atranslate() takes from 3 to 8 positional arguments but 9 were given" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranslate() - assert str(e.value) == "atranslate() missing 2 required positional argument(s): model, file" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranslate( - "whisper-1" - ) - assert str(e.value) == "atranslate() missing 1 required positional argument(s): file" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranslate( - model="whisper-1" - ) - assert str(e.value) == "atranslate() missing 1 required positional argument(s): file" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranslate( - file=open(AUDIO_FILE_PATH, "rb") - ) - assert str(e.value) == "atranslate() missing 1 required positional argument(s): model" - - # # Valid - openai.api_base = "https://api.openai.com/v1" - openai.api_type = "openai" - openai.api_key = OPENAI_API_KEY - openai.api_version = None - audio = await openai.Audio.atranslate( - "whisper-1", - open(AUDIO_FILE_PATH, "rb") - ) - assert audio - - audio1 = await openai.Audio.atranslate( - model="whisper-1", - file=open(AUDIO_FILE_PATH, "rb") - ) - assert audio1 - - openai.api_base = API_BASE - openai.api_key = AZURE_API_KEY - openai.api_type = "azure" - openai.api_version = API_VERSION - audio = await openai.Audio.atranslate( - deployment_id="whisper-1", - file=open(AUDIO_FILE_PATH, "rb"), - ) - assert audio - - -def test_translate_raw(): - - # Invalid - with pytest.raises(TypeError) as e: - openai.Audio.translate_raw( - "whisper-1", - open(AUDIO_FILE_PATH, "rb").read(), - "filename", - "api_key", - "api_base", - "api_type", - "api_version", - "organization", - "extra", - ) - assert str(e.value) == "translate_raw() takes from 4 to 9 positional arguments but 10 were given" - - with pytest.raises(TypeError) as e: - openai.Audio.translate_raw() - assert str(e.value) == "translate_raw() missing 3 required positional argument(s): model, file, filename" - - with pytest.raises(TypeError) as e: - openai.Audio.translate_raw( - "whisper-1" - ) - assert str(e.value) == "translate_raw() missing 2 required positional argument(s): file, filename" - - with pytest.raises(TypeError) as e: - openai.Audio.translate_raw( - "whisper-1", - open(AUDIO_FILE_PATH, "rb").read() - ) - assert str(e.value) == "translate_raw() missing 1 required positional argument(s): filename" - - with pytest.raises(TypeError) as e: - openai.Audio.translate_raw( - model="whisper-1" - ) - assert str(e.value) == "translate_raw() missing 2 required positional argument(s): file, filename" - - with pytest.raises(TypeError) as e: - openai.Audio.translate_raw( - file=open(AUDIO_FILE_PATH, "rb").read() - ) - assert str(e.value) == "translate_raw() missing 2 required positional argument(s): model, filename" - - with pytest.raises(TypeError) as e: - openai.Audio.translate_raw( - filename=AUDIO_FILE_NAME - ) - assert str(e.value) == "translate_raw() missing 2 required positional argument(s): model, file" - - with pytest.raises(TypeError) as e: - openai.Audio.translate_raw( - model="whisper-1", - file=open(AUDIO_FILE_PATH, "rb").read() - ) - assert str(e.value) == "translate_raw() missing 1 required positional argument(s): filename" - - with pytest.raises(TypeError) as e: - openai.Audio.translate_raw( - model="whisper-1", - filename=AUDIO_FILE_NAME - ) - assert str(e.value) == "translate_raw() missing 1 required positional argument(s): file" - - - with pytest.raises(TypeError) as e: - openai.Audio.translate_raw( - file=open(AUDIO_FILE_PATH, "rb").read(), - filename=AUDIO_FILE_NAME - ) - assert str(e.value) == "translate_raw() missing 1 required positional argument(s): model" - - - # Valid - openai.api_base = "https://api.openai.com/v1" - openai.api_type = "openai" - openai.api_key = OPENAI_API_KEY - openai.api_version = None - audio = openai.Audio.translate_raw( - "whisper-1", - open(AUDIO_FILE_PATH, "rb").read(), - AUDIO_FILE_NAME - ) - assert audio - - audio = openai.Audio.translate_raw( - model="whisper-1", - file=open(AUDIO_FILE_PATH, "rb").read(), - filename=AUDIO_FILE_NAME - ) - assert audio - - openai.api_base = API_BASE - openai.api_key = AZURE_API_KEY - openai.api_type = "azure" - openai.api_version = API_VERSION - audio = openai.Audio.translate_raw( - deployment_id="whisper-1", - file=open(AUDIO_FILE_PATH, "rb").read(), - filename=AUDIO_FILE_NAME - ) - assert audio - - -@pytest.mark.asyncio -async def test_atranslate_raw(): - - # Invalid - with pytest.raises(TypeError) as e: - await openai.Audio.atranslate_raw( - "whisper-1", - open(AUDIO_FILE_PATH, "rb").read(), - "filename", - "api_key", - "api_base", - "api_type", - "api_version", - "organization", - "extra", - ) - assert str(e.value) == "atranslate_raw() takes from 4 to 9 positional arguments but 10 were given" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranslate_raw() - assert str(e.value) == "atranslate_raw() missing 3 required positional argument(s): model, file, filename" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranslate_raw( - "whisper-1" - ) - assert str(e.value) == "atranslate_raw() missing 2 required positional argument(s): file, filename" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranslate_raw( - "whisper-1", - open(AUDIO_FILE_PATH, "rb").read() - ) - assert str(e.value) == "atranslate_raw() missing 1 required positional argument(s): filename" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranslate_raw( - model="whisper-1" - ) - assert str(e.value) == "atranslate_raw() missing 2 required positional argument(s): file, filename" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranslate_raw( - file=open(AUDIO_FILE_PATH, "rb").read() - ) - assert str(e.value) == "atranslate_raw() missing 2 required positional argument(s): model, filename" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranslate_raw( - filename=AUDIO_FILE_NAME - ) - assert str(e.value) == "atranslate_raw() missing 2 required positional argument(s): model, file" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranslate_raw( - model="whisper-1", - file=open(AUDIO_FILE_PATH, "rb").read() - ) - assert str(e.value) == "atranslate_raw() missing 1 required positional argument(s): filename" - - with pytest.raises(TypeError) as e: - await openai.Audio.atranslate_raw( - model="whisper-1", - filename=AUDIO_FILE_NAME - ) - assert str(e.value) == "atranslate_raw() missing 1 required positional argument(s): file" - - - with pytest.raises(TypeError) as e: - await openai.Audio.atranslate_raw( - file=open(AUDIO_FILE_PATH, "rb").read(), - filename=AUDIO_FILE_NAME - ) - assert str(e.value) == "atranslate_raw() missing 1 required positional argument(s): model" - - - # Valid - openai.api_base = "https://api.openai.com/v1" - openai.api_type = "openai" - openai.api_key = OPENAI_API_KEY - openai.api_version = None - audio = await openai.Audio.atranslate_raw( - "whisper-1", - open(AUDIO_FILE_PATH, "rb").read(), - AUDIO_FILE_NAME - ) - assert audio - - audio = await openai.Audio.atranslate_raw( - model="whisper-1", - file=open(AUDIO_FILE_PATH, "rb").read(), - filename=AUDIO_FILE_NAME - ) - assert audio - - openai.api_base = API_BASE - openai.api_key = AZURE_API_KEY - openai.api_type = "azure" - openai.api_version = API_VERSION - audio = await openai.Audio.atranslate_raw( - deployment_id="whisper-1", - file=open(AUDIO_FILE_PATH, "rb").read(), - filename=AUDIO_FILE_NAME - ) - assert audio diff --git a/openai/util.py b/openai/util.py index 5d16092a1b..5501d5b67e 100644 --- a/openai/util.py +++ b/openai/util.py @@ -186,21 +186,3 @@ def default_api_key() -> str: raise openai.error.AuthenticationError( "No API key provided. You can set your API key in code using 'openai.api_key = ', or you can set the environment variable OPENAI_API_KEY=). If your API key is stored in a file, you can point the openai module at it with 'openai.api_key_path = '. You can generate API keys in the OpenAI web interface. See https://platform.openai.com/account/api-keys for details." ) - - -def check_required(*args, method_name, required, **kwargs): - """Checks that all required parameters have been provided - to the method where overloads are used to maintain existing behavior. - """ - missing = [] - args_count = len(args) - for param in required: - if param in kwargs: - continue - if args_count > 0: - args_count -= 1 - continue - missing.append(param) - - if missing and "deployment_id" not in kwargs: - raise TypeError(f"{method_name}() missing {len(missing)} required positional argument(s): {', '.join(missing)}")