diff --git a/pylsp_jsonrpc/endpoint.py b/pylsp_jsonrpc/endpoint.py index 9bbe4c8..4c0b7c2 100644 --- a/pylsp_jsonrpc/endpoint.py +++ b/pylsp_jsonrpc/endpoint.py @@ -4,6 +4,7 @@ import logging import uuid import sys +import asyncio from concurrent import futures from .exceptions import (JsonRpcException, JsonRpcRequestCancelled, @@ -12,6 +13,7 @@ log = logging.getLogger(__name__) JSONRPC_VERSION = '2.0' CANCEL_METHOD = '$/cancelRequest' +EXIT_METHOD = 'exit' class Endpoint: @@ -35,9 +37,24 @@ def __init__(self, dispatcher, consumer, id_generator=lambda: str(uuid.uuid4()), self._client_request_futures = {} self._server_request_futures = {} self._executor_service = futures.ThreadPoolExecutor(max_workers=max_workers) + self._cancelledRequests = set() + self._messageQueue = None + self._consume_task = None + + def init_async(self): + self._messageQueue = asyncio.Queue() + self._consume_task = asyncio.create_task(self.consume_task()) + + async def consume_task(self): + while self._consume_task is not None and not self._consume_task.cancelled(): + message = await self._messageQueue.get() + await asyncio.to_thread(self.consume, message) + self._messageQueue.task_done() def shutdown(self): self._executor_service.shutdown() + if self._consume_task is not None: + self._consume_task.cancel() def notify(self, method, params=None): """Send a JSON RPC notification to the client. @@ -94,6 +111,21 @@ def callback(future): future.set_exception(JsonRpcRequestCancelled()) return callback + async def consume_async(self, message): + """Consume a JSON RPC message from the client and put it into a queue. + + Args: + message (dict): The JSON RPC message sent by the client + """ + if message['method'] == CANCEL_METHOD: + self._cancelledRequests.add(message.get('params')['id']) + + # The exit message needs to be handled directly since the stream cannot be closed asynchronously + if message['method'] == EXIT_METHOD: + self.consume(message) + else: + await self._messageQueue.put(message) + def consume(self, message): """Consume a JSON RPC message from the client. @@ -182,6 +214,9 @@ def _handle_request(self, msg_id, method, params): except KeyError as e: raise JsonRpcMethodNotFound.of(method) from e + if msg_id in self._cancelledRequests: + raise JsonRpcRequestCancelled() + handler_result = handler(params) if callable(handler_result): diff --git a/pylsp_jsonrpc/streams.py b/pylsp_jsonrpc/streams.py index 40048a9..ffa3865 100644 --- a/pylsp_jsonrpc/streams.py +++ b/pylsp_jsonrpc/streams.py @@ -3,6 +3,7 @@ import logging import threading +import asyncio try: import ujson as json @@ -65,6 +66,30 @@ def _read_message(self): # Grab the body return self._rfile.read(content_length) + async def listen_async(self, message_consumer): + """Blocking call to listen for messages on the rfile. + + Args: + message_consumer (fn): function that is passed each message as it is read off the socket. + """ + + while not self._rfile.closed: + try: + request_str = await asyncio.to_thread(self._read_message) + except ValueError: + if self._rfile.closed: + return + log.exception("Failed to read from rfile") + + if request_str is None: + break + + try: + await message_consumer(json.loads(request_str.decode('utf-8'))) + except ValueError: + log.exception("Failed to parse JSON message %s", request_str) + continue + @staticmethod def _content_length(line): """Extract the content length from an input line.""" diff --git a/test/test_endpoint.py b/test/test_endpoint.py index 08fb62d..8964ed8 100644 --- a/test/test_endpoint.py +++ b/test/test_endpoint.py @@ -10,7 +10,7 @@ from pylsp_jsonrpc.endpoint import Endpoint MSG_ID = 'id' - +EXIT_METHOD = 'exit' @pytest.fixture() def dispatcher(): @@ -319,6 +319,60 @@ def test_consume_request_cancel_unknown(endpoint): }) +@pytest.mark.asyncio +async def test_consume_async_request_cancel(endpoint, dispatcher, consumer): + def async_handler(): + time.sleep(1) + handler = mock.Mock(return_value=async_handler) + dispatcher['methodName'] = handler + + endpoint.init_async() + + await endpoint.consume_async({ + 'jsonrpc': '2.0', + 'method': 'methodName', + 'params': {'key': 'value'} + }) + await endpoint.consume_async({ + 'jsonrpc': '2.0', + 'id': MSG_ID, + 'method': 'methodName', + 'params': {'key': 'value'} + }) + await endpoint.consume_async({ + 'jsonrpc': '2.0', + 'method': '$/cancelRequest', + 'params': {'id': MSG_ID} + }) + + await endpoint._messageQueue.join() + + consumer.assert_called_once_with({ + 'jsonrpc': '2.0', + 'id': MSG_ID, + 'error': exceptions.JsonRpcRequestCancelled().to_dict() + }) + + endpoint.shutdown() + + +@pytest.mark.asyncio +async def test_consume_async_exit(endpoint, dispatcher, consumer): + # verify that exit is still called synchronously + handler = mock.Mock() + dispatcher[EXIT_METHOD] = handler + + endpoint.init_async() + + await endpoint.consume_async({ + 'jsonrpc': '2.0', + 'method': EXIT_METHOD + }) + + handler.assert_called_once_with(None) + + endpoint.shutdown() + def assert_consumer_error(consumer_mock, exception): """Assert that the consumer mock has had once call with the given error message and code. diff --git a/test/test_streams.py b/test/test_streams.py index 8ded7fe..4358155 100644 --- a/test/test_streams.py +++ b/test/test_streams.py @@ -76,6 +76,53 @@ def test_reader_bad_json(rfile, reader): consumer.assert_not_called() +@pytest.mark.asyncio +async def test_reader_async(rfile, reader): + rfile.write( + b'Content-Length: 49\r\n' + b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n' + b'\r\n' + b'{"id": "hello", "method": "method", "params": {}}' + ) + rfile.seek(0) + + consumer = mock.AsyncMock() + await reader.listen_async(consumer) + + consumer.assert_called_once_with({ + 'id': 'hello', + 'method': 'method', + 'params': {} + }) + + +@pytest.mark.asyncio +async def test_reader_bad_message_async(rfile, reader): + rfile.write(b'Hello world') + rfile.seek(0) + + # Ensure the listener doesn't throw + consumer = mock.AsyncMock() + await reader.listen_async(consumer) + consumer.assert_not_called() + + +@pytest.mark.asyncio +async def test_reader_bad_json_async(rfile, reader): + rfile.write( + b'Content-Length: 8\r\n' + b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n' + b'\r\n' + b'{hello}}' + ) + rfile.seek(0) + + # Ensure the listener doesn't throw + consumer = mock.AsyncMock() + await reader.listen_async(consumer) + consumer.assert_not_called() + + def test_writer(wfile, writer): writer.write({ 'id': 'hello', @@ -124,5 +171,9 @@ def test_writer_bad_message(wfile, writer): b'Content-Length: 10\r\n' b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n' b'\r\n' - b'1546322461' + b'1546322461', + b'Content-Length: 10\r\n' + b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n' + b'\r\n' + b'1546300861' ]