diff --git a/Cargo.lock b/Cargo.lock index 68ee4c3f..1f3e9680 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2488,6 +2488,7 @@ dependencies = [ "serde_json", "text-size", "tokio", + "tower", "tower-lsp", "tracing", ] @@ -4023,8 +4024,10 @@ dependencies = [ "futures-util", "pin-project", "pin-project-lite", + "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] diff --git a/crates/pglt_lsp/Cargo.toml b/crates/pglt_lsp/Cargo.toml index e8e6dc24..a1f440a1 100644 --- a/crates/pglt_lsp/Cargo.toml +++ b/crates/pglt_lsp/Cargo.toml @@ -33,6 +33,9 @@ tower-lsp = { version = "0.20.0" } tracing = { workspace = true, features = ["attributes"] } [dev-dependencies] +tokio = { workspace = true, features = ["macros"] } +tower = { version = "0.4.13", features = ["timeout"] } + [lib] doctest = false diff --git a/crates/pglt_lsp/tests/server.rs b/crates/pglt_lsp/tests/server.rs new file mode 100644 index 00000000..990df566 --- /dev/null +++ b/crates/pglt_lsp/tests/server.rs @@ -0,0 +1,316 @@ +use anyhow::bail; +use anyhow::Context; +use anyhow::Error; +use anyhow::Result; +use futures::channel::mpsc::{channel, Sender}; +use futures::Sink; +use futures::SinkExt; +use futures::Stream; +use futures::StreamExt; +use pglt_lsp::LSPServer; +use pglt_lsp::ServerFactory; +use serde::de::DeserializeOwned; +use serde::Serialize; +use serde_json::{from_value, to_value}; +use std::any::type_name; +use std::fmt::Display; +use std::time::Duration; +use tower::timeout::Timeout; +use tower::{Service, ServiceExt}; +use tower_lsp::jsonrpc; +use tower_lsp::jsonrpc::Response; +use tower_lsp::lsp_types as lsp; +use tower_lsp::lsp_types::{ + ClientCapabilities, DidChangeConfigurationParams, DidChangeTextDocumentParams, + DidCloseTextDocumentParams, DidOpenTextDocumentParams, InitializeResult, InitializedParams, + PublishDiagnosticsParams, TextDocumentContentChangeEvent, TextDocumentIdentifier, + TextDocumentItem, Url, VersionedTextDocumentIdentifier, +}; +use tower_lsp::LspService; +use tower_lsp::{jsonrpc::Request, lsp_types::InitializeParams}; + +/// Statically build an [Url] instance that points to the file at `$path` +/// within the workspace. The filesystem path contained in the return URI is +/// guaranteed to be a valid path for the underlying operating system, but +/// doesn't have to refer to an existing file on the host machine. +macro_rules! url { + ($path:literal) => { + if cfg!(windows) { + lsp::Url::parse(concat!("file:///z%3A/workspace/", $path)).unwrap() + } else { + lsp::Url::parse(concat!("file:///workspace/", $path)).unwrap() + } + }; +} + +struct Server { + service: Timeout>, +} + +impl Server { + fn new(service: LspService) -> Self { + Self { + service: Timeout::new(service, Duration::from_secs(1)), + } + } + + async fn notify

(&mut self, method: &'static str, params: P) -> Result<()> + where + P: Serialize, + { + self.service + .ready() + .await + .map_err(Error::msg) + .context("ready() returned an error")? + .call( + Request::build(method) + .params(to_value(¶ms).context("failed to serialize params")?) + .finish(), + ) + .await + .map_err(Error::msg) + .context("call() returned an error") + .and_then(|res| { + if let Some(res) = res { + bail!("shutdown returned {:?}", res) + } else { + Ok(()) + } + }) + } + + async fn request( + &mut self, + method: &'static str, + id: &'static str, + params: P, + ) -> Result> + where + P: Serialize, + R: DeserializeOwned, + { + self.service + .ready() + .await + .map_err(Error::msg) + .context("ready() returned an error")? + .call( + Request::build(method) + .id(id) + .params(to_value(¶ms).context("failed to serialize params")?) + .finish(), + ) + .await + .map_err(Error::msg) + .context("call() returned an error")? + .map(|res| { + let (_, body) = res.into_parts(); + + let body = + body.with_context(|| format!("response to {method:?} contained an error"))?; + + from_value(body.clone()).with_context(|| { + format!( + "failed to deserialize type {} from response {body:?}", + type_name::() + ) + }) + }) + .transpose() + } + + /// Basic implementation of the `initialize` request for tests + // The `root_path` field is deprecated, but we still need to specify it + #[allow(deprecated)] + async fn initialize(&mut self) -> Result<()> { + let _res: InitializeResult = self + .request( + "initialize", + "_init", + InitializeParams { + process_id: None, + root_path: None, + root_uri: Some(url!("")), + initialization_options: None, + capabilities: ClientCapabilities::default(), + trace: None, + workspace_folders: None, + client_info: None, + locale: None, + }, + ) + .await? + .context("initialize returned None")?; + + Ok(()) + } + + /// Basic implementation of the `initialized` notification for tests + async fn initialized(&mut self) -> Result<()> { + self.notify("initialized", InitializedParams {}).await + } + + /// Basic implementation of the `shutdown` notification for tests + async fn shutdown(&mut self) -> Result<()> { + self.service + .ready() + .await + .map_err(Error::msg) + .context("ready() returned an error")? + .call(Request::build("shutdown").finish()) + .await + .map_err(Error::msg) + .context("call() returned an error") + .and_then(|res| { + if let Some(res) = res { + bail!("shutdown returned {:?}", res) + } else { + Ok(()) + } + }) + } + + async fn open_document(&mut self, text: impl Display) -> Result<()> { + self.notify( + "textDocument/didOpen", + DidOpenTextDocumentParams { + text_document: TextDocumentItem { + uri: url!("document.sql"), + language_id: String::from("sql"), + version: 0, + text: text.to_string(), + }, + }, + ) + .await + } + + /// Opens a document with given contents and given name. The name must contain the extension too + async fn open_named_document(&mut self, text: impl Display, document_name: Url) -> Result<()> { + self.notify( + "textDocument/didOpen", + DidOpenTextDocumentParams { + text_document: TextDocumentItem { + uri: document_name, + language_id: String::from("sql"), + version: 0, + text: text.to_string(), + }, + }, + ) + .await + } + + /// When calling this function, remember to insert the file inside the memory file system + async fn load_configuration(&mut self) -> Result<()> { + self.notify( + "workspace/didChangeConfiguration", + DidChangeConfigurationParams { + settings: to_value(()).unwrap(), + }, + ) + .await + } + + async fn change_document( + &mut self, + version: i32, + content_changes: Vec, + ) -> Result<()> { + self.notify( + "textDocument/didChange", + DidChangeTextDocumentParams { + text_document: VersionedTextDocumentIdentifier { + uri: url!("document.sql"), + version, + }, + content_changes, + }, + ) + .await + } + + async fn close_document(&mut self) -> Result<()> { + self.notify( + "textDocument/didClose", + DidCloseTextDocumentParams { + text_document: TextDocumentIdentifier { + uri: url!("document.sql"), + }, + }, + ) + .await + } + + /// Basic implementation of the `pglt/shutdown` request for tests + async fn pglt_shutdown(&mut self) -> Result<()> { + self.request::<_, ()>("pglt/shutdown", "_pglt_shutdown", ()) + .await? + .context("pglt/shutdown returned None")?; + Ok(()) + } +} + +/// Number of notifications buffered by the server-to-client channel before it starts blocking the current task +const CHANNEL_BUFFER_SIZE: usize = 8; + +#[derive(Debug, PartialEq, Eq)] +enum ServerNotification { + PublishDiagnostics(PublishDiagnosticsParams), +} + +/// Basic handler for requests and notifications coming from the server for tests +async fn client_handler( + mut stream: I, + mut sink: O, + mut notify: Sender, +) -> Result<()> +where + // This function has to be generic as `RequestStream` and `ResponseSink` + // are not exported from `tower_lsp` and cannot be named in the signature + I: Stream + Unpin, + O: Sink + Unpin, +{ + while let Some(req) = stream.next().await { + if req.method() == "textDocument/publishDiagnostics" { + let params = req.params().expect("invalid request"); + let diagnostics = from_value(params.clone()).expect("invalid params"); + let notification = ServerNotification::PublishDiagnostics(diagnostics); + match notify.send(notification).await { + Ok(_) => continue, + Err(_) => break, + } + } + + let id = match req.id() { + Some(id) => id, + None => continue, + }; + + let res = Response::from_error(id.clone(), jsonrpc::Error::method_not_found()); + + sink.send(res).await.ok(); + } + + Ok(()) +} + +#[tokio::test] +async fn basic_lifecycle() -> Result<()> { + let factory = ServerFactory::default(); + let (service, client) = factory.create(None).into_inner(); + let (stream, sink) = client.split(); + let mut server = Server::new(service); + + let (sender, _) = channel(CHANNEL_BUFFER_SIZE); + let reader = tokio::spawn(client_handler(stream, sink, sender)); + + server.initialize().await?; + server.initialized().await?; + + server.shutdown().await?; + reader.abort(); + + Ok(()) +}