From 56ffbbc4cd8b9922f8e2ce49102c6155ce42e7f8 Mon Sep 17 00:00:00 2001
From: w0rp
Date: Tue, 22 Apr 2025 19:57:10 +0100
Subject: [PATCH] Merge the chatgpt source into a configurable openai source

---
 autoload/neural.vim                |  10 +-
 autoload/neural/config.vim         |  98 +++++++++----
 autoload/neural/source/chatgpt.vim |  10 --
 autoload/neural/source/openai.vim  |   2 +-
 neural_providers/chatgpt.py        | 204 ---------------------------
 neural_providers/openai.py         | 106 ++++++++++----
 test/python/test_chatgpt.py        | 214 -----------------------------
 test/python/test_openai.py         | 117 +++++++++++++---
 to-do                              |   7 +
 9 files changed, 267 insertions(+), 501 deletions(-)
 delete mode 100644 autoload/neural/source/chatgpt.vim
 delete mode 100644 neural_providers/chatgpt.py
 delete mode 100644 test/python/test_chatgpt.py
 create mode 100644 to-do

diff --git a/autoload/neural.vim b/autoload/neural.vim
index ba06cc1..6e256a5 100644
--- a/autoload/neural.vim
+++ b/autoload/neural.vim
@@ -235,12 +235,13 @@ function! neural#PreProcess(buffer, input) abort
 endfunction
 
 function! s:LoadDataSource() abort
-    let l:selected = g:neural.selected
+    " TODO: Change message if nothing is selected.
+    let l:type = get(g:neural.sources, 0, {'type': ''}).type
 
     try
-        let l:source = function('neural#source#' . selected . '#Get')()
+        let l:source = function('neural#source#' . l:type . '#Get')()
     catch /E117/
-        call neural#OutputErrorMessage('Invalid source: ' . l:selected)
+        call neural#OutputErrorMessage('Invalid source: ' . l:type)
 
         return
     endtry
@@ -249,7 +250,8 @@ function! s:LoadDataSource() abort
 endfunction
 
 function! s:GetSourceInput(buffer, source, prompt) abort
-    let l:config = get(g:neural.source, a:source.name, {})
+    " FIXME: Pass the config around.
+    let l:config = get(g:neural.sources, 0, {})
 
     " If the config is not a Dictionary, throw it away.
     if type(l:config) isnot v:t_dict
diff --git a/autoload/neural/config.vim b/autoload/neural/config.vim
index 933ee4f..54cbf96 100644
--- a/autoload/neural/config.vim
+++ b/autoload/neural/config.vim
@@ -5,8 +5,22 @@ scriptencoding utf-8
 " Track modifications to g:neural, in case we set it again.
 let s:last_dictionary = get(s:, 'last_dictionary', {})
 
+" TODO: Default use_chat_api value here instead of in Python?
+let s:source_defaults = {
+\   'openai': {
+\       'url': 'https://api.openai.com',
+\       'api_key': '',
+\       'frequency_penalty': 0.1,
+\       'max_tokens': 1024,
+\       'model': 'gpt-3.5-turbo-instruct',
+\       'use_chat_api': v:false,
+\       'presence_penalty': 0.1,
+\       'temperature': 0.2,
+\       'top_p': 1,
+\   },
+\}
+
 let s:defaults = {
-\   'selected': 'openai',
 \   'pre_process': {
 \       'enabled': v:true,
 \   },
@@ -22,26 +36,7 @@ let s:defaults = {
 \       'create_mode': 'vertical',
 \       'wrap': v:true,
 \   },
-\   'source': {
-\       'openai': {
-\           'api_key': '',
-\           'frequency_penalty': 0.1,
-\           'max_tokens': 1024,
-\           'model': 'gpt-3.5-turbo-instruct',
-\           'presence_penalty': 0.1,
-\           'temperature': 0.2,
-\           'top_p': 1,
-\       },
-\       'chatgpt': {
-\           'api_key': '',
-\           'frequency_penalty': 0.1,
-\           'max_tokens': 2048,
-\           'model': 'gpt-3.5-turbo',
-\           'presence_penalty': 0.1,
-\           'temperature': 0.2,
-\           'top_p': 1,
-\       },
-\   },
+\   'sources': [],
 \}
 
 function! neural#config#DeepMerge(into, from) abort
@@ -56,28 +51,75 @@ function! neural#config#DeepMerge(into, from) abort
     return a:into
 endfunction
 
-function! s:ApplySpecialDefaults() abort
-    if empty(g:neural.source.chatgpt.api_key)
-        let g:neural.source.chatgpt.api_key = g:neural.source.openai.api_key
-    endif
-endfunction
-
 " Set the shared configuration for Neural.
 function! neural#config#Set(settings) abort
     let g:neural = a:settings
 
     call neural#config#Load()
 endfunction
 
+function! neural#config#ConvertLegacySettings(dictionary) abort
+    " Replace 'source' with the newer 'sources'.
+    if has_key(a:dictionary, 'source') && !has_key(a:dictionary, 'sources')
+        let l:source = remove(a:dictionary, 'source')
+        let a:dictionary.sources = []
+
+        if type(l:source) is v:t_dict
+            " Keep the old behavior of defaulting the chatgpt key to the openai key.
+            let l:default_api_key = get(get(l:source, 'openai', {}), 'api_key', '')
+
+            for [l:type, l:settings] in items(l:source)
+                let l:settings = copy(l:settings)
+                let l:settings.use_chat_api = l:type is# 'chatgpt' ? v:true : v:false
+                let l:settings.type = l:type is# 'chatgpt' ? 'openai' : l:type
+
+                if empty(get(l:settings, 'api_key'))
+                    let l:settings.api_key = l:default_api_key
+                endif
+
+                call add(a:dictionary.sources, l:settings)
+            endfor
+        endif
+    endif
+
+    " Remove the 'selected' key if set.
+    if has_key(a:dictionary, 'selected')
+        call remove(a:dictionary, 'selected')
+    endif
+endfunction
+
+function! neural#config#MergeSourceDefaults(sources) abort
+    let l:merged_sources = []
+
+    if type(a:sources) is v:t_list
+        for l:source in a:sources
+            let l:type = get(l:source, 'type', v:null)
+
+            call add(l:merged_sources, neural#config#DeepMerge(
+            \   deepcopy(get(s:source_defaults, l:type, {})),
+            \   l:source,
+            \))
+        endfor
+    endif
+
+    return l:merged_sources
+endfunction
+
 function! neural#config#Load() abort
     let l:dictionary = get(g:, 'neural', {})
 
     " Merge the Dictionary with defaults again if g:neural changed.
     if l:dictionary isnot# s:last_dictionary
+        " Create a shallow copy to modify.
+        let l:dictionary = copy(l:dictionary)
+        call neural#config#ConvertLegacySettings(l:dictionary)
+        let l:dictionary.sources = neural#config#MergeSourceDefaults(
+        \   get(l:dictionary, 'sources', v:null)
+        \)
+
         let s:last_dictionary = neural#config#DeepMerge(
         \   deepcopy(s:defaults),
         \   l:dictionary,
         \)
         let g:neural = s:last_dictionary
-
-        call s:ApplySpecialDefaults()
     endif
 endfunction
diff --git a/autoload/neural/source/chatgpt.vim b/autoload/neural/source/chatgpt.vim
deleted file mode 100644
index 27735f8..0000000
--- a/autoload/neural/source/chatgpt.vim
+++ /dev/null
@@ -1,10 +0,0 @@
-" Author: w0rp
-" Description: A script describing how to use ChatGPT with Neural
-
-function! neural#source#chatgpt#Get() abort
-    return {
-    \   'name': 'chatgpt',
-    \   'script_language': 'python',
-    \   'script': neural#GetScriptDir() . '/chatgpt.py',
-    \}
-endfunction
diff --git a/autoload/neural/source/openai.vim b/autoload/neural/source/openai.vim
index ff81b9c..1adb6d6 100644
--- a/autoload/neural/source/openai.vim
+++ b/autoload/neural/source/openai.vim
@@ -1,5 +1,5 @@
 " Author: w0rp
-" Description: A script describing how to use OpenAI with Neural
+" Description: A script describing how to use OpenAI-compatible APIs with Neural
 
 function! neural#source#openai#Get() abort
     return {
diff --git a/neural_providers/chatgpt.py b/neural_providers/chatgpt.py
deleted file mode 100644
index e981c62..0000000
--- a/neural_providers/chatgpt.py
+++ /dev/null
@@ -1,204 +0,0 @@
-"""
-A Neural datasource for ChatGPT conversations.
-"""
-import json
-import platform
-import ssl
-import sys
-import urllib.error
-import urllib.request
-from typing import Any, Dict, List, Optional, Union
-
-API_ENDPOINT = 'https://api.openai.com/v1/chat/completions'
-
-OPENAI_DATA_HEADER = 'data: '
-OPENAI_DONE = '[DONE]'
-
-
-class Config:
-    """
-    The sanitised configuration.
- """ - def __init__( - self, - api_key: str, - model: str, - temperature: float, - top_p: float, - max_tokens: int, - presence_penalty: float, - frequency_penalty: float, - ): - self.api_key = api_key - self.model = model - self.temperature = temperature - self.top_p = top_p - self.max_tokens = max_tokens - self.presence_penalty = presence_penalty - self.frequency_penalty = frequency_penalty - - -def get_chatgpt_completion( - config: Config, - prompt: Union[str, List[Dict[str, str]]], -) -> None: - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {config.api_key}" - } - data = { - "model": config.model, - "messages": ( - [{"role": "user", "content": prompt}] - if isinstance(prompt, str) else - prompt - ), - "temperature": config.temperature, - "max_tokens": config.max_tokens, - "top_p": 1, - "presence_penalty": config.presence_penalty, - "frequency_penalty": config.frequency_penalty, - "stream": True, - } - - req = urllib.request.Request( - API_ENDPOINT, - data=json.dumps(data).encode("utf-8"), - headers=headers, - method="POST", - ) - role: Optional[str] = None - - # Disable SSL certificate verification on macOS. - # This is bad for security, and we need to deal with SSL errors better. - # - # This is the error: - # urllib.error.URLError: # noqa - context = ( - ssl._create_unverified_context() # type: ignore - if platform.system() == "Darwin" else - None - ) - - with urllib.request.urlopen(req, context=context) as response: - while True: - line_bytes = response.readline() - - if not line_bytes: - break - - line = line_bytes.decode("utf-8", errors="replace") - line_data = ( - line[len(OPENAI_DATA_HEADER):-1] - if line.startswith(OPENAI_DATA_HEADER) else - None - ) - - if line_data and line_data != OPENAI_DONE: - delta = json.loads(line_data)["choices"][0]["delta"] - # The role is typically in the first delta only. - role = delta.get("role", role) - - if role == "assistant" and "content" in delta: - print(delta["content"], end="", flush=True) - - print() - - -def load_config(raw_config: Dict[str, Any]) -> Config: - # TODO: Add range validation for request parameters. 
- if not isinstance(raw_config, dict): # type: ignore - raise ValueError("chatgpt config is not a dictionary") - - api_key = raw_config.get('api_key') - - if not isinstance(api_key, str) or not api_key: # type: ignore - raise ValueError("chatgpt.api_key is not defined") - - model = raw_config.get('model') - - if not isinstance(model, str) or not model: - raise ValueError("chatgpt.model is not defined") - - temperature = raw_config.get('temperature', 0.2) - - if not isinstance(temperature, (int, float)): - raise ValueError("chatgpt.temperature is invalid") - - top_p = raw_config.get('top_p', 1) - - if not isinstance(top_p, (int, float)): - raise ValueError("chatgpt.top_p is invalid") - - max_tokens = raw_config.get('max_tokens', 1024) - - if not isinstance(max_tokens, (int)): - raise ValueError("chatgpt.max_tokens is invalid") - - presence_penalty = raw_config.get('presence_penalty', 0) - - if not isinstance(presence_penalty, (int, float)): - raise ValueError("chatgpt.presence_penalty is invalid") - - frequency_penalty = raw_config.get('frequency_penalty', 0) - - if not isinstance(frequency_penalty, (int, float)): - raise ValueError("chatgpt.frequency_penalty is invalid") - - return Config( - api_key=api_key, - model=model, - temperature=temperature, - top_p=top_p, - max_tokens=max_tokens, - presence_penalty=presence_penalty, - frequency_penalty=presence_penalty, - ) - - -def get_error_message(error: urllib.error.HTTPError) -> str: - message = error.read().decode('utf-8', errors='ignore') - - try: - # JSON data might look like this: - # { - # "error": { - # "message": "...", - # "type": "...", - # "param": null, - # "code": null - # } - # } - message = json.loads(message)['error']['message'] - - if "This model's maximum context length is" in message: - message = 'Too much text for a request!' - except Exception: - # If we can't get a better message use the JSON payload at least. - pass - - return message - - -def main() -> None: - input_data = json.loads(sys.stdin.readline()) - - try: - config = load_config(input_data["config"]) - except ValueError as err: - sys.exit(str(err)) - - try: - get_chatgpt_completion(config, input_data["prompt"]) - except urllib.error.HTTPError as error: - if error.code == 400 or error.code == 401: - message = get_error_message(error) - sys.exit('Neural error: OpenAI request failure: ' + message) - elif error.code == 429: - sys.exit("Neural error: OpenAI request limit reached!") - else: - raise - - -if __name__ == "__main__": # pragma: no cover - main() # pragma: no cover diff --git a/neural_providers/openai.py b/neural_providers/openai.py index 7cf3bff..e6cd0f7 100644 --- a/neural_providers/openai.py +++ b/neural_providers/openai.py @@ -1,5 +1,5 @@ """ -A Neural datasource for loading generated text via OpenAI. +A Neural datasource for GPT conversations. 
""" import json import platform @@ -7,9 +7,7 @@ import sys import urllib.error import urllib.request -from typing import Any, Dict - -API_ENDPOINT = 'https://api.openai.com/v1/completions' +from typing import Any, Dict, List, Optional, Union OPENAI_DATA_HEADER = 'data: ' OPENAI_DONE = '[DONE]' @@ -21,16 +19,20 @@ class Config: """ def __init__( self, + url: str, api_key: str, model: str, + use_chat_api: bool, temperature: float, top_p: float, max_tokens: int, presence_penalty: float, frequency_penalty: float, ): + self.url = url self.api_key = api_key self.model = model + self.use_chat_api = use_chat_api self.temperature = temperature self.top_p = top_p self.max_tokens = max_tokens @@ -38,14 +40,16 @@ def __init__( self.frequency_penalty = frequency_penalty -def get_openai_completion(config: Config, prompt: str) -> None: +def get_openai_completion( + config: Config, + prompt: Union[str, List[Dict[str, str]]], +) -> None: headers = { "Content-Type": "application/json", "Authorization": f"Bearer {config.api_key}" } - data = { + data: Dict[str, Any] = { "model": config.model, - "prompt": prompt, "temperature": config.temperature, "max_tokens": config.max_tokens, "top_p": 1, @@ -54,13 +58,26 @@ def get_openai_completion(config: Config, prompt: str) -> None: "stream": True, } + if config.use_chat_api: + data["messages"] = ( + [{"role": "user", "content": prompt}] + if isinstance(prompt, str) else + prompt + ) + else: + data["prompt"] = prompt + req = urllib.request.Request( - API_ENDPOINT, + ( + f'{config.url}/v1/chat/completions' + if config.use_chat_api else + f'{config.url}/v1/completions' + ), data=json.dumps(data).encode("utf-8"), headers=headers, method="POST", - unverifiable=True, ) + role: Optional[str] = None # Disable SSL certificate verification on macOS. # This is bad for security, and we need to deal with SSL errors better. @@ -81,15 +98,23 @@ def get_openai_completion(config: Config, prompt: str) -> None: break line = line_bytes.decode("utf-8", errors="replace") - - if line.startswith(OPENAI_DATA_HEADER): - line_data = line[len(OPENAI_DATA_HEADER):-1] - - if line_data == OPENAI_DONE: - pass + line_data = ( + line[len(OPENAI_DATA_HEADER):-1] + if line.startswith(OPENAI_DATA_HEADER) else + None + ) + + if line_data and line_data != OPENAI_DONE: + openai_obj = json.loads(line_data) + + if config.use_chat_api: + delta = openai_obj["choices"][0]["delta"] + # The role is typically in the first delta only. 
+                    role = delta.get("role", role)
+
+                    if role == "assistant" and "content" in delta:
+                        print(delta["content"], end="", flush=True)
                 else:
-                    openai_obj = json.loads(line_data)
-
                     print(openai_obj["choices"][0]["text"], end="", flush=True)
 
         print()
@@ -100,44 +125,74 @@ def load_config(raw_config: Dict[str, Any]) -> Config:
     if not isinstance(raw_config, dict):  # type: ignore
         raise ValueError("openai config is not a dictionary")
 
+    url = raw_config.get('url')
+
+    if url is None:
+        url = 'https://api.openai.com'
+    elif not isinstance(url, str):
+        raise ValueError("url must be a string")
+    elif not url.startswith("http://") and not url.startswith("https://"):
+        raise ValueError("url must start with http(s)://")
+
     api_key = raw_config.get('api_key')
 
     if not isinstance(api_key, str) or not api_key:  # type: ignore
-        raise ValueError("openai.api_key is not defined")
+        raise ValueError("api_key is not defined")
 
     model = raw_config.get('model')
 
     if not isinstance(model, str) or not model:
-        raise ValueError("openai.model is not defined")
+        raise ValueError("model is not defined")
+
+    use_chat_api = raw_config.get('use_chat_api')
+
+    if use_chat_api is None:
+        # Default to the older completions API if using certain older models.
+        use_chat_api = model not in (
+            'ada',
+            'babbage',
+            'curie',
+            'davinci',
+            'gpt-3.5-turbo-instruct',
+            'text-ada-001',
+            'text-babbage-001',
+            'text-curie-001',
+            'text-davinci-002',
+            'text-davinci-003',
+        )
+    elif not isinstance(use_chat_api, bool):
+        raise ValueError("use_chat_api must be true or false")
 
     temperature = raw_config.get('temperature', 0.2)
 
     if not isinstance(temperature, (int, float)):
-        raise ValueError("openai.temperature is invalid")
+        raise ValueError("temperature is invalid")
 
     top_p = raw_config.get('top_p', 1)
 
     if not isinstance(top_p, (int, float)):
-        raise ValueError("openai.top_p is invalid")
+        raise ValueError("top_p is invalid")
 
     max_tokens = raw_config.get('max_tokens', 1024)
 
     if not isinstance(max_tokens, (int)):
-        raise ValueError("openai.max_tokens is invalid")
+        raise ValueError("max_tokens is invalid")
 
     presence_penalty = raw_config.get('presence_penalty', 0)
 
     if not isinstance(presence_penalty, (int, float)):
-        raise ValueError("openai.presence_penalty is invalid")
+        raise ValueError("presence_penalty is invalid")
 
     frequency_penalty = raw_config.get('frequency_penalty', 0)
 
     if not isinstance(frequency_penalty, (int, float)):
-        raise ValueError("openai.frequency_penalty is invalid")
+        raise ValueError("frequency_penalty is invalid")
 
     return Config(
+        url=url,
         api_key=api_key,
         model=model,
+        use_chat_api=use_chat_api,
         temperature=temperature,
         top_p=top_p,
         max_tokens=max_tokens,
@@ -184,6 +239,9 @@ def main() -> None:
         if error.code == 400 or error.code == 401:
             message = get_error_message(error)
             sys.exit('Neural error: OpenAI request failure: ' + message)
+        elif error.code == 404:
+            message = get_error_message(error)
+            sys.exit('Neural error: OpenAI request failure: ' + message)
         elif error.code == 429:
             sys.exit("Neural error: OpenAI request limit reached!")
         else:
             raise
diff --git a/test/python/test_chatgpt.py b/test/python/test_chatgpt.py
deleted file mode 100644
index 3fff5b5..0000000
--- a/test/python/test_chatgpt.py
+++ /dev/null
@@ -1,214 +0,0 @@
-import json
-import sys
-import urllib.error
-import urllib.request
-from io import BytesIO
-from typing import Any, Dict, Optional, cast
-from unittest import mock
-
-import pytest
-
-from neural_providers import chatgpt
-
-
-def get_valid_config() -> Dict[str, Any]:
-    return {
-        "api_key": ".",
-        "model": "foo",
-
"prompt": "say hello", - "temperature": 1, - "top_p": 1, - "max_tokens": 1, - "presence_penalty": 1, - "frequency_penalty": 1, - } - - -def test_load_config_errors(): - with pytest.raises(ValueError) as exc: - chatgpt.load_config(cast(Any, 0)) - - assert str(exc.value) == "chatgpt config is not a dictionary" - - config: Dict[str, Any] = {} - - for modification, expected_error in [ - ({}, "chatgpt.api_key is not defined"), - ({"api_key": ""}, "chatgpt.api_key is not defined"), - ({"api_key": "."}, "chatgpt.model is not defined"), - ({"model": ""}, "chatgpt.model is not defined"), - ( - {"model": "x", "temperature": "x"}, - "chatgpt.temperature is invalid" - ), - ( - {"temperature": 1, "top_p": "x"}, - "chatgpt.top_p is invalid" - ), - ( - {"top_p": 1, "max_tokens": "x"}, - "chatgpt.max_tokens is invalid" - ), - ( - {"max_tokens": 1, "presence_penalty": "x"}, - "chatgpt.presence_penalty is invalid" - ), - ( - {"presence_penalty": 1, "frequency_penalty": "x"}, - "chatgpt.frequency_penalty is invalid" - ), - ]: - config.update(modification) - - with pytest.raises(ValueError) as exc: - chatgpt.load_config(config) - - assert str(exc.value) == expected_error, config - - -def test_main_function_rate_other_error(): - with mock.patch.object(sys.stdin, 'readline') as readline_mock, \ - mock.patch.object(chatgpt, 'get_chatgpt_completion') as compl_mock: - - compl_mock.side_effect = urllib.error.HTTPError( - url='', - msg='', - hdrs=mock.Mock(), - fp=None, - code=500, - ) - readline_mock.return_value = json.dumps({ - "config": get_valid_config(), - "prompt": "hello there", - }) - - with pytest.raises(urllib.error.HTTPError): - chatgpt.main() - - -def test_print_chatgpt_results(): - result_data = ( - b'data: {"id":"chatcmpl-6tMwjovREOTA84MkGBOS5rWyj1izv","object":"chat.completion.chunk","created":1678654265,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}\n' # noqa - b'\n' - b'data: {"id":"chatcmpl-6tMwjovREOTA84MkGBOS5rWyj1izv","object":"chat.completion.chunk","created":1678654265,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"\\n\\n"},"index":0,"finish_reason":null}]}\n' # noqa - b'\n' - b'data: {"id":"chatcmpl-6tMwjovREOTA84MkGBOS5rWyj1izv","object":"chat.completion.chunk","created":1678654265,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"This"},"index":0,"finish_reason":null}]}\n' # noqa - b'\n' - b'data: {"id":"chatcmpl-6tMwjovREOTA84MkGBOS5rWyj1izv","object":"chat.completion.chunk","created":1678654265,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":" is"},"index":0,"finish_reason":null}]}\n' # noqa - b'\n' - b'data: {"id":"chatcmpl-6tMwjovREOTA84MkGBOS5rWyj1izv","object":"chat.completion.chunk","created":1678654265,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":" a"},"index":0,"finish_reason":null}]}\n' # noqa - b'\n' - b'data: {"id":"chatcmpl-6tMwjovREOTA84MkGBOS5rWyj1izv","object":"chat.completion.chunk","created":1678654265,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":" test"},"index":0,"finish_reason":null}]}\n' # noqa - b'\n' - b'data: {"id":"chatcmpl-6tMwjovREOTA84MkGBOS5rWyj1izv","object":"chat.completion.chunk","created":1678654265,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"."},"index":0,"finish_reason":null}]}\n' # noqa - b'\n' - b'data: {"id":"chatcmpl-6tMwjovREOTA84MkGBOS5rWyj1izv","object":"chat.completion.chunk","created":1678654265,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{},"index":0,"finish_reason":"length"}]}\n' # noqa - b'\n' - 
b'data: [DONE]\n' - b'\n' - ) - - with mock.patch.object(sys.stdin, 'readline') as readline_mock, \ - mock.patch.object(urllib.request, 'urlopen') as urlopen_mock, \ - mock.patch('builtins.print') as print_mock: - - urlopen_mock.return_value.__enter__.return_value = BytesIO(result_data) - - readline_mock.return_value = json.dumps({ - "config": get_valid_config(), - "prompt": "Say this is a test", - }) - chatgpt.main() - - assert print_mock.call_args_list == [ - mock.call('\n\n', end='', flush=True), - mock.call('This', end='', flush=True), - mock.call(' is', end='', flush=True), - mock.call(' a', end='', flush=True), - mock.call(' test', end='', flush=True), - mock.call('.', end='', flush=True), - mock.call(), - ] - - -def test_main_function_bad_config(): - with mock.patch.object(sys.stdin, 'readline') as readline_mock, \ - mock.patch.object(chatgpt, 'load_config') as load_config_mock: - - load_config_mock.side_effect = ValueError("expect this") - readline_mock.return_value = json.dumps({"config": {}}) - - with pytest.raises(SystemExit) as exc: - chatgpt.main() - - assert str(exc.value) == 'expect this' - - -@pytest.mark.parametrize( - 'code, error_text, expected_message', - ( - pytest.param( - 429, - None, - 'OpenAI request limit reached!', - id="request_limit", - ), - pytest.param( - 400, - '{]', - 'OpenAI request failure: {]', - id="error_with_mangled_json", - ), - pytest.param( - 400, - json.dumps({'error': {}}), - 'OpenAI request failure: {"error": {}}', - id="error_with_missing_message_key", - ), - pytest.param( - 400, - json.dumps({ - 'error': { - 'message': "This model's maximum context length is 123", - }, - }), - 'OpenAI request failure: Too much text for a request!', - id="too_much_text", - ), - pytest.param( - 401, - json.dumps({ - 'error': { - 'message': "Bad authentication error", - }, - }), - 'OpenAI request failure: Bad authentication error', - id="unauthorised_failure", - ), - ) -) -def test_api_error( - code: int, - error_text: Optional[str], - expected_message: str, -): - with mock.patch.object(sys.stdin, 'readline') as readline_mock, \ - mock.patch.object(chatgpt, 'get_chatgpt_completion') as compl_mock: - - compl_mock.side_effect = urllib.error.HTTPError( - url='', - msg='', - hdrs=mock.Mock(), - fp=BytesIO(error_text.encode('utf-8')) if error_text else None, - code=code, - ) - - readline_mock.return_value = json.dumps({ - "config": get_valid_config(), - "prompt": "hello there", - }) - - with pytest.raises(SystemExit) as exc: - chatgpt.main() - - assert str(exc.value) == f'Neural error: {expected_message}' diff --git a/test/python/test_openai.py b/test/python/test_openai.py index 24d4c10..4110821 100644 --- a/test/python/test_openai.py +++ b/test/python/test_openai.py @@ -11,10 +11,10 @@ from neural_providers import openai -def get_valid_config() -> Dict[str, Any]: +def get_valid_config(model: str = "foo") -> Dict[str, Any]: return { "api_key": ".", - "model": "foo", + "model": model, "prompt": "say hello", "temperature": 1, "top_p": 1, @@ -33,29 +33,34 @@ def test_load_config_errors(): config: Dict[str, Any] = {} for modification, expected_error in [ - ({}, "openai.api_key is not defined"), - ({"api_key": ""}, "openai.api_key is not defined"), - ({"api_key": "."}, "openai.model is not defined"), - ({"model": ""}, "openai.model is not defined"), + ({"url": 1}, "url must be a string"), + ({"url": "x"}, "url must start with http(s)://"), + ({"url": "https://x", "api_key": ""}, "api_key is not defined"), + ({"api_key": "."}, "model is not defined"), + ({"model": ""}, 
"model is not defined"), ( - {"model": "x", "temperature": "x"}, - "openai.temperature is invalid" + {"model": "x", "use_chat_api": 1}, + "use_chat_api must be true or false" + ), + ( + {"use_chat_api": None, "temperature": "x"}, + "temperature is invalid" ), ( {"temperature": 1, "top_p": "x"}, - "openai.top_p is invalid" + "top_p is invalid" ), ( {"top_p": 1, "max_tokens": "x"}, - "openai.max_tokens is invalid" + "max_tokens is invalid" ), ( {"max_tokens": 1, "presence_penalty": "x"}, - "openai.presence_penalty is invalid" + "presence_penalty is invalid" ), ( {"presence_penalty": 1, "frequency_penalty": "x"}, - "openai.frequency_penalty is invalid" + "frequency_penalty is invalid" ), ]: config.update(modification) @@ -66,11 +71,46 @@ def test_load_config_errors(): assert str(exc.value) == expected_error, config +def test_automatic_completions_api_usage(): + raw_config = get_valid_config() + + for model in ( + 'ada', + 'babbage', + 'curie', + 'davinci', + 'gpt-3.5-turbo-instruct', + 'text-ada-001', + 'text-babbage-001', + 'text-curie-001', + 'text-davinci-002', + 'text-davinci-003', + ): + raw_config['model'] = model + + assert openai.load_config(raw_config).use_chat_api is False + + for model in ('gpt-3.5', 'gpt-4'): + raw_config['model'] = model + + assert openai.load_config(raw_config).use_chat_api is True + + +def test_url_configuration(): + raw_config = get_valid_config() + + assert openai.load_config(raw_config).url == 'https://api.openai.com' + + raw_config['url'] = 'http://myhost' + + assert openai.load_config(raw_config).url == 'http://myhost' + + def test_main_function_rate_other_error(): with mock.patch.object(sys.stdin, 'readline') as readline_mock, \ - mock.patch.object(openai, 'get_openai_completion') as completion_mock: + mock.patch.object(openai, 'get_openai_completion') as compl_mock: - completion_mock.side_effect = urllib.error.HTTPError( + compl_mock.side_effect = urllib.error.HTTPError( url='', msg='', hdrs=mock.Mock(), @@ -86,7 +126,7 @@ def test_main_function_rate_other_error(): openai.main() -def test_print_openai_results(): +def test_print_openai_completion_results(): result_data = ( b'data: {"id": "cmpl-6jMlRJtbYTGrNwE6Lxy1Ns1EtD0is", "object": "text_completion", "created": 1676270285, "choices": [{"text": "\\n", "index": 0, "logprobs": null, "finish_reason": null}], "model": "gpt-3.5-turbo-instruct"}\n' # noqa b'\n' @@ -109,7 +149,7 @@ def test_print_openai_results(): urlopen_mock.return_value.__enter__.return_value = BytesIO(result_data) readline_mock.return_value = json.dumps({ - "config": get_valid_config(), + "config": get_valid_config("gpt-3.5-turbo-instruct"), "prompt": "hello there", }) openai.main() @@ -124,6 +164,51 @@ def test_print_openai_results(): ] +def test_print_openai_chat_completion_results(): + result_data = ( + b'data: {"id":"chatcmpl-6tMwjovREOTA84MkGBOS5rWyj1izv","object":"chat.completion.chunk","created":1678654265,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}\n' # noqa + b'\n' + b'data: {"id":"chatcmpl-6tMwjovREOTA84MkGBOS5rWyj1izv","object":"chat.completion.chunk","created":1678654265,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"\\n\\n"},"index":0,"finish_reason":null}]}\n' # noqa + b'\n' + b'data: {"id":"chatcmpl-6tMwjovREOTA84MkGBOS5rWyj1izv","object":"chat.completion.chunk","created":1678654265,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"This"},"index":0,"finish_reason":null}]}\n' # noqa + b'\n' + b'data: 
{"id":"chatcmpl-6tMwjovREOTA84MkGBOS5rWyj1izv","object":"chat.completion.chunk","created":1678654265,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":" is"},"index":0,"finish_reason":null}]}\n' # noqa + b'\n' + b'data: {"id":"chatcmpl-6tMwjovREOTA84MkGBOS5rWyj1izv","object":"chat.completion.chunk","created":1678654265,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":" a"},"index":0,"finish_reason":null}]}\n' # noqa + b'\n' + b'data: {"id":"chatcmpl-6tMwjovREOTA84MkGBOS5rWyj1izv","object":"chat.completion.chunk","created":1678654265,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":" test"},"index":0,"finish_reason":null}]}\n' # noqa + b'\n' + b'data: {"id":"chatcmpl-6tMwjovREOTA84MkGBOS5rWyj1izv","object":"chat.completion.chunk","created":1678654265,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"."},"index":0,"finish_reason":null}]}\n' # noqa + b'\n' + b'data: {"id":"chatcmpl-6tMwjovREOTA84MkGBOS5rWyj1izv","object":"chat.completion.chunk","created":1678654265,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{},"index":0,"finish_reason":"length"}]}\n' # noqa + b'\n' + b'data: [DONE]\n' + b'\n' + ) + + with mock.patch.object(sys.stdin, 'readline') as readline_mock, \ + mock.patch.object(urllib.request, 'urlopen') as urlopen_mock, \ + mock.patch('builtins.print') as print_mock: + + urlopen_mock.return_value.__enter__.return_value = BytesIO(result_data) + + readline_mock.return_value = json.dumps({ + "config": get_valid_config("gpt-3.5-turbo-0301"), + "prompt": "Say this is a test", + }) + openai.main() + + assert print_mock.call_args_list == [ + mock.call('\n\n', end='', flush=True), + mock.call('This', end='', flush=True), + mock.call(' is', end='', flush=True), + mock.call(' a', end='', flush=True), + mock.call(' test', end='', flush=True), + mock.call('.', end='', flush=True), + mock.call(), + ] + + def test_main_function_bad_config(): with mock.patch.object(sys.stdin, 'readline') as readline_mock, \ mock.patch.object(openai, 'load_config') as load_config_mock: diff --git a/to-do b/to-do new file mode 100644 index 0000000..7827bfc --- /dev/null +++ b/to-do @@ -0,0 +1,7 @@ +* Rename autoload source file to provider +* Print a nice message when you have no sources +* Pass the config for the selected source across +* Update tests for configuration +* Update other tests? +* Make the chat thing look for the first chat-capable source +* Move the use_chat_api default from Python to Vim script