diff --git a/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h b/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h new file mode 100644 index 0000000000000..4811ecb5e92b7 --- /dev/null +++ b/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h @@ -0,0 +1,30 @@ +//===- MlirLspRegistryFunction.h - LSP registry functions -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Registry function types for MLIR LSP. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPREGISTRYFUNCTION_H +#define MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPREGISTRYFUNCTION_H + +namespace llvm { +template +class function_ref; +} // namespace llvm + +namespace mlir { +class DialectRegistry; +namespace lsp { +class URIForFile; +using DialectRegistryFn = + llvm::function_ref; +} // namespace lsp +} // namespace mlir + +#endif // MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPREGISTRYFUNCTION_H diff --git a/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspServerMain.h b/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspServerMain.h index 66d5d40a6d28d..a461fc4702946 100644 --- a/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspServerMain.h +++ b/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspServerMain.h @@ -12,20 +12,27 @@ #ifndef MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPSERVERMAIN_H #define MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPSERVERMAIN_H +#include "mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h" namespace llvm { struct LogicalResult; } // namespace llvm namespace mlir { -class DialectRegistry; /// Implementation for tools like `mlir-lsp-server`. /// - registry should contain all the dialects that can be parsed in source IR -/// passed to the server. +/// passed to the server. llvm::LogicalResult MlirLspServerMain(int argc, char **argv, DialectRegistry ®istry); +/// Implementation for tools like `mlir-lsp-server`. +/// - registry should contain all the dialects that can be parsed in source IR +/// passed to the server and may register different dialects depending on the +/// input URI. +llvm::LogicalResult MlirLspServerMain(int argc, char **argv, + lsp::DialectRegistryFn registry_fn); + } // namespace mlir #endif // MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPSERVERMAIN_H diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp index 4e19274c3da40..61987525a5ca5 100644 --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp @@ -997,7 +997,7 @@ namespace { class MLIRTextFile { public: MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents, - int64_t version, DialectRegistry ®istry, + int64_t version, lsp::DialectRegistryFn registry_fn, std::vector &diagnostics); /// Return the current version of this text file. @@ -1046,9 +1046,9 @@ class MLIRTextFile { } // namespace MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents, - int64_t version, DialectRegistry ®istry, + int64_t version, lsp::DialectRegistryFn registry_fn, std::vector &diagnostics) - : context(registry, MLIRContext::Threading::DISABLED), + : context(registry_fn(uri), MLIRContext::Threading::DISABLED), contents(fileContents.str()), version(version) { context.allowUnregisteredDialects(); @@ -1263,11 +1263,11 @@ MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) { //===----------------------------------------------------------------------===// struct lsp::MLIRServer::Impl { - Impl(DialectRegistry ®istry) : registry(registry) {} + Impl(lsp::DialectRegistryFn registry_fn) : registry_fn(registry_fn) {} - /// The registry containing dialects that can be recognized in parsed .mlir - /// files. - DialectRegistry ®istry; + /// The registry factory for containing dialects that can be recognized in + /// parsed .mlir files. + lsp::DialectRegistryFn registry_fn; /// The files held by the server, mapped by their URI file name. llvm::StringMap> files; @@ -1277,15 +1277,15 @@ struct lsp::MLIRServer::Impl { // MLIRServer //===----------------------------------------------------------------------===// -lsp::MLIRServer::MLIRServer(DialectRegistry ®istry) - : impl(std::make_unique(registry)) {} +lsp::MLIRServer::MLIRServer(lsp::DialectRegistryFn registry_fn) + : impl(std::make_unique(registry_fn)) {} lsp::MLIRServer::~MLIRServer() = default; void lsp::MLIRServer::addOrUpdateDocument( const URIForFile &uri, StringRef contents, int64_t version, std::vector &diagnostics) { impl->files[uri.file()] = std::make_unique( - uri, contents, version, impl->registry, diagnostics); + uri, contents, version, impl->registry_fn, diagnostics); } std::optional lsp::MLIRServer::removeDocument(const URIForFile &uri) { @@ -1348,7 +1348,7 @@ void lsp::MLIRServer::getCodeActions(const URIForFile &uri, const Range &pos, llvm::Expected lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) { - MLIRContext tempContext(impl->registry); + MLIRContext tempContext(impl->registry_fn(uri)); tempContext.allowUnregisteredDialects(); // Collect any errors during parsing. diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h index 979be615b82cc..85e69e69f6631 100644 --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h @@ -10,6 +10,7 @@ #define LIB_MLIR_TOOLS_MLIRLSPSERVER_SERVER_H_ #include "mlir/Support/LLVM.h" +#include "mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h" #include "llvm/Support/Error.h" #include #include @@ -28,15 +29,14 @@ struct Location; struct MLIRConvertBytecodeResult; struct Position; struct Range; -class URIForFile; /// This class implements all of the MLIR related functionality necessary for a /// language server. This class allows for keeping the MLIR specific logic /// separate from the logic that involves LSP server/client communication. class MLIRServer { public: - /// Construct a new server with the given dialect regitstry. - MLIRServer(DialectRegistry ®istry); + /// Construct a new server with the given dialect registry function. + MLIRServer(DialectRegistryFn registry_fn); ~MLIRServer(); /// Add or update the document, with the provided `version`, at the given URI. diff --git a/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp index 259bd2613a6cc..b1bbf98ce769e 100644 --- a/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp @@ -19,7 +19,7 @@ using namespace mlir; using namespace mlir::lsp; LogicalResult mlir::MlirLspServerMain(int argc, char **argv, - DialectRegistry ®istry) { + DialectRegistryFn registry_fn) { llvm::cl::opt inputStyle{ "input-style", llvm::cl::desc("Input JSON stream encoding"), @@ -72,6 +72,15 @@ LogicalResult mlir::MlirLspServerMain(int argc, char **argv, URIForFile::registerSupportedScheme("mlir.bytecode-mlir"); // Configure the servers and start the main language server. - MLIRServer server(registry); + MLIRServer server(registry_fn); return runMlirLSPServer(server, transport); } + +llvm::LogicalResult mlir::MlirLspServerMain(int argc, char **argv, + DialectRegistry ®istry) { + auto registry_fn = + [®istry](const lsp::URIForFile &uri) -> DialectRegistry & { + return registry; + }; + return MlirLspServerMain(argc, argv, registry_fn); +} diff --git a/mlir/test/mlir-lsp-server/uri-based-registration.test b/mlir/test/mlir-lsp-server/uri-based-registration.test new file mode 100644 index 0000000000000..d6d06692a8d7a --- /dev/null +++ b/mlir/test/mlir-lsp-server/uri-based-registration.test @@ -0,0 +1,23 @@ +// RUN: not mlir-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s +{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"mlir","capabilities":{},"trace":"off"}} +// ----- +// Just regular parse, successful. +{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{ + "uri":"test:///foo-regular-registration.mlir", + "languageId":"mlir", + "version":1, + "text":"func.func @fail_with_empty_registry() { return }" +}}} +// CHECK: "method": "textDocument/publishDiagnostics", +// CHECK: "diagnostics": [] +// ----- +// Just regular parse, successful. +{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{ + "uri":"test:///foo-disable-lsp-registration.mlir", + "languageId":"mlir", + "version":1, + "text":"func.func @fail_with_empty_registry() { return }" +}}} +// CHECK: "method": "textDocument/publishDiagnostics", +// CHECK: "message": "Dialect `func' not found for custom op 'func.func' + diff --git a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp index f0ecc5adc68b3..6a759d9e0d60f 100644 --- a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp +++ b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp @@ -6,10 +6,10 @@ // //===----------------------------------------------------------------------===// -#include "mlir/IR/Dialect.h" #include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllExtensions.h" +#include "mlir/Tools/lsp-server-support/Protocol.h" #include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" using namespace mlir; @@ -23,7 +23,7 @@ void registerTestTransformDialectExtension(DialectRegistry &); #endif int main(int argc, char **argv) { - DialectRegistry registry; + DialectRegistry registry, empty; registerAllDialects(registry); registerAllExtensions(registry); @@ -32,5 +32,18 @@ int main(int argc, char **argv) { ::test::registerTestTransformDialectExtension(registry); ::test::registerTestDynDialect(registry); #endif - return failed(MlirLspServerMain(argc, argv, registry)); + + // Returns the registry, except in testing mode when the URI contains + // "-disable-lsp-registration". Testing for/example of registering dialects + // based on URI. + auto registryFn = [®istry, + &empty](const lsp::URIForFile &uri) -> DialectRegistry & { + (void)empty; +#ifdef MLIR_INCLUDE_TESTS + if (uri.uri().contains("-disable-lsp-registration")) + return empty; +#endif + return registry; + }; + return failed(MlirLspServerMain(argc, argv, registryFn)); } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 1cbc632952c4d..8b1c76717e2f2 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8835,11 +8835,51 @@ cc_binary( name = "mlir-lsp-server", srcs = ["tools/mlir-lsp-server/mlir-lsp-server.cpp"], includes = ["include"], + local_defines = ["MLIR_INCLUDE_TESTS"], deps = [ ":AllExtensions", ":AllPassesAndDialects", ":IR", ":MlirLspServerLib", + ":MlirLspServerSupportLib", + "//mlir/test:TestAffine", + "//mlir/test:TestAnalysis", + "//mlir/test:TestArith", + "//mlir/test:TestArmNeon", + "//mlir/test:TestArmSME", + "//mlir/test:TestBufferization", + "//mlir/test:TestControlFlow", + "//mlir/test:TestConvertToSPIRV", + "//mlir/test:TestDLTI", + "//mlir/test:TestDialect", + "//mlir/test:TestFunc", + "//mlir/test:TestFuncToLLVM", + "//mlir/test:TestGPU", + "//mlir/test:TestIR", + "//mlir/test:TestLLVM", + "//mlir/test:TestLinalg", + "//mlir/test:TestLoopLikeInterface", + "//mlir/test:TestMath", + "//mlir/test:TestMathToVCIX", + "//mlir/test:TestMemRef", + "//mlir/test:TestMesh", + "//mlir/test:TestNVGPU", + "//mlir/test:TestPDLL", + "//mlir/test:TestPass", + "//mlir/test:TestReducer", + "//mlir/test:TestRewrite", + "//mlir/test:TestSCF", + "//mlir/test:TestSPIRV", + "//mlir/test:TestShapeDialect", + "//mlir/test:TestTensor", + "//mlir/test:TestTestDynDialect", + "//mlir/test:TestTilingInterface", + "//mlir/test:TestTosaDialect", + "//mlir/test:TestTransformDialect", + "//mlir/test:TestTransforms", + "//mlir/test:TestVector", + "//mlir/test:TestVectorToSPIRV", + "//mlir/test:TestXeGPU", ], )