Skip to content

[mlir][lsp] Enable registering dialects based on URI. #141331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//===- MlirLspRegistryFunction.h - LSP registry functions -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Registry function types for MLIR LSP.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPREGISTRYFUNCTION_H
#define MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPREGISTRYFUNCTION_H

namespace llvm {
template <typename Fn>
class function_ref;
} // namespace llvm

namespace mlir {
class DialectRegistry;
namespace lsp {
class URIForFile;
using DialectRegistryFn =
llvm::function_ref<DialectRegistry &(const URIForFile &uri)>;
} // namespace lsp
} // namespace mlir

#endif // MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPREGISTRYFUNCTION_H
11 changes: 9 additions & 2 deletions mlir/include/mlir/Tools/mlir-lsp-server/MlirLspServerMain.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,27 @@

#ifndef MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPSERVERMAIN_H
#define MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPSERVERMAIN_H
#include "mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h"

namespace llvm {
struct LogicalResult;
} // namespace llvm

namespace mlir {
class DialectRegistry;

/// Implementation for tools like `mlir-lsp-server`.
/// - registry should contain all the dialects that can be parsed in source IR
/// passed to the server.
/// passed to the server.
llvm::LogicalResult MlirLspServerMain(int argc, char **argv,
DialectRegistry &registry);

/// Implementation for tools like `mlir-lsp-server`.
/// - registry should contain all the dialects that can be parsed in source IR
/// passed to the server and may register different dialects depending on the
/// input URI.
llvm::LogicalResult MlirLspServerMain(int argc, char **argv,
lsp::DialectRegistryFn registry_fn);

} // namespace mlir

#endif // MLIR_TOOLS_MLIR_LSP_SERVER_MLIRLSPSERVERMAIN_H
22 changes: 11 additions & 11 deletions mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,7 @@ namespace {
class MLIRTextFile {
public:
MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
int64_t version, DialectRegistry &registry,
int64_t version, lsp::DialectRegistryFn registry_fn,
std::vector<lsp::Diagnostic> &diagnostics);

/// Return the current version of this text file.
Expand Down Expand Up @@ -1046,9 +1046,9 @@ class MLIRTextFile {
} // namespace

MLIRTextFile::MLIRTextFile(const lsp::URIForFile &uri, StringRef fileContents,
int64_t version, DialectRegistry &registry,
int64_t version, lsp::DialectRegistryFn registry_fn,
std::vector<lsp::Diagnostic> &diagnostics)
: context(registry, MLIRContext::Threading::DISABLED),
: context(registry_fn(uri), MLIRContext::Threading::DISABLED),
contents(fileContents.str()), version(version) {
context.allowUnregisteredDialects();

Expand Down Expand Up @@ -1263,11 +1263,11 @@ MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) {
//===----------------------------------------------------------------------===//

struct lsp::MLIRServer::Impl {
Impl(DialectRegistry &registry) : registry(registry) {}
Impl(lsp::DialectRegistryFn registry_fn) : registry_fn(registry_fn) {}

/// The registry containing dialects that can be recognized in parsed .mlir
/// files.
DialectRegistry &registry;
/// The registry factory for containing dialects that can be recognized in
/// parsed .mlir files.
lsp::DialectRegistryFn registry_fn;

/// The files held by the server, mapped by their URI file name.
llvm::StringMap<std::unique_ptr<MLIRTextFile>> files;
Expand All @@ -1277,15 +1277,15 @@ struct lsp::MLIRServer::Impl {
// MLIRServer
//===----------------------------------------------------------------------===//

lsp::MLIRServer::MLIRServer(DialectRegistry &registry)
: impl(std::make_unique<Impl>(registry)) {}
lsp::MLIRServer::MLIRServer(lsp::DialectRegistryFn registry_fn)
: impl(std::make_unique<Impl>(registry_fn)) {}
lsp::MLIRServer::~MLIRServer() = default;

void lsp::MLIRServer::addOrUpdateDocument(
const URIForFile &uri, StringRef contents, int64_t version,
std::vector<Diagnostic> &diagnostics) {
impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
uri, contents, version, impl->registry, diagnostics);
uri, contents, version, impl->registry_fn, diagnostics);
}

std::optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
Expand Down Expand Up @@ -1348,7 +1348,7 @@ void lsp::MLIRServer::getCodeActions(const URIForFile &uri, const Range &pos,

llvm::Expected<lsp::MLIRConvertBytecodeResult>
lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) {
MLIRContext tempContext(impl->registry);
MLIRContext tempContext(impl->registry_fn(uri));
tempContext.allowUnregisteredDialects();

// Collect any errors during parsing.
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Tools/mlir-lsp-server/MLIRServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define LIB_MLIR_TOOLS_MLIRLSPSERVER_SERVER_H_

#include "mlir/Support/LLVM.h"
#include "mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h"
#include "llvm/Support/Error.h"
#include <memory>
#include <optional>
Expand All @@ -28,15 +29,14 @@ struct Location;
struct MLIRConvertBytecodeResult;
struct Position;
struct Range;
class URIForFile;

/// This class implements all of the MLIR related functionality necessary for a
/// language server. This class allows for keeping the MLIR specific logic
/// separate from the logic that involves LSP server/client communication.
class MLIRServer {
public:
/// Construct a new server with the given dialect regitstry.
MLIRServer(DialectRegistry &registry);
/// Construct a new server with the given dialect registry function.
MLIRServer(DialectRegistryFn registry_fn);
~MLIRServer();

/// Add or update the document, with the provided `version`, at the given URI.
Expand Down
13 changes: 11 additions & 2 deletions mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ using namespace mlir;
using namespace mlir::lsp;

LogicalResult mlir::MlirLspServerMain(int argc, char **argv,
DialectRegistry &registry) {
DialectRegistryFn registry_fn) {
llvm::cl::opt<JSONStreamStyle> inputStyle{
"input-style",
llvm::cl::desc("Input JSON stream encoding"),
Expand Down Expand Up @@ -72,6 +72,15 @@ LogicalResult mlir::MlirLspServerMain(int argc, char **argv,
URIForFile::registerSupportedScheme("mlir.bytecode-mlir");

// Configure the servers and start the main language server.
MLIRServer server(registry);
MLIRServer server(registry_fn);
return runMlirLSPServer(server, transport);
}

llvm::LogicalResult mlir::MlirLspServerMain(int argc, char **argv,
DialectRegistry &registry) {
auto registry_fn =
[&registry](const lsp::URIForFile &uri) -> DialectRegistry & {
return registry;
};
return MlirLspServerMain(argc, argv, registry_fn);
}
23 changes: 23 additions & 0 deletions mlir/test/mlir-lsp-server/uri-based-registration.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: not mlir-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s
{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"mlir","capabilities":{},"trace":"off"}}
// -----
// Just regular parse, successful.
{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
"uri":"test:///foo-regular-registration.mlir",
"languageId":"mlir",
"version":1,
"text":"func.func @fail_with_empty_registry() { return }"
}}}
// CHECK: "method": "textDocument/publishDiagnostics",
// CHECK: "diagnostics": []
// -----
// Just regular parse, successful.
{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
"uri":"test:///foo-disable-lsp-registration.mlir",
"languageId":"mlir",
"version":1,
"text":"func.func @fail_with_empty_registry() { return }"
}}}
// CHECK: "method": "textDocument/publishDiagnostics",
// CHECK: "message": "Dialect `func' not found for custom op 'func.func'

19 changes: 16 additions & 3 deletions mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllExtensions.h"
#include "mlir/Tools/lsp-server-support/Protocol.h"
#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h"

using namespace mlir;
Expand All @@ -23,7 +23,7 @@ void registerTestTransformDialectExtension(DialectRegistry &);
#endif

int main(int argc, char **argv) {
DialectRegistry registry;
DialectRegistry registry, empty;
registerAllDialects(registry);
registerAllExtensions(registry);

Expand All @@ -32,5 +32,18 @@ int main(int argc, char **argv) {
::test::registerTestTransformDialectExtension(registry);
::test::registerTestDynDialect(registry);
#endif
return failed(MlirLspServerMain(argc, argv, registry));

// Returns the registry, except in testing mode when the URI contains
// "-disable-lsp-registration". Testing for/example of registering dialects
// based on URI.
auto registryFn = [&registry,
&empty](const lsp::URIForFile &uri) -> DialectRegistry & {
(void)empty;
#ifdef MLIR_INCLUDE_TESTS
if (uri.uri().contains("-disable-lsp-registration"))
return empty;
#endif
return registry;
};
return failed(MlirLspServerMain(argc, argv, registryFn));
}
40 changes: 40 additions & 0 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8835,11 +8835,51 @@ cc_binary(
name = "mlir-lsp-server",
srcs = ["tools/mlir-lsp-server/mlir-lsp-server.cpp"],
includes = ["include"],
local_defines = ["MLIR_INCLUDE_TESTS"],
deps = [
":AllExtensions",
":AllPassesAndDialects",
":IR",
":MlirLspServerLib",
":MlirLspServerSupportLib",
"//mlir/test:TestAffine",
"//mlir/test:TestAnalysis",
"//mlir/test:TestArith",
"//mlir/test:TestArmNeon",
"//mlir/test:TestArmSME",
"//mlir/test:TestBufferization",
"//mlir/test:TestControlFlow",
"//mlir/test:TestConvertToSPIRV",
"//mlir/test:TestDLTI",
"//mlir/test:TestDialect",
"//mlir/test:TestFunc",
"//mlir/test:TestFuncToLLVM",
"//mlir/test:TestGPU",
"//mlir/test:TestIR",
"//mlir/test:TestLLVM",
"//mlir/test:TestLinalg",
"//mlir/test:TestLoopLikeInterface",
"//mlir/test:TestMath",
"//mlir/test:TestMathToVCIX",
"//mlir/test:TestMemRef",
"//mlir/test:TestMesh",
"//mlir/test:TestNVGPU",
"//mlir/test:TestPDLL",
"//mlir/test:TestPass",
"//mlir/test:TestReducer",
"//mlir/test:TestRewrite",
"//mlir/test:TestSCF",
"//mlir/test:TestSPIRV",
"//mlir/test:TestShapeDialect",
"//mlir/test:TestTensor",
"//mlir/test:TestTestDynDialect",
"//mlir/test:TestTilingInterface",
"//mlir/test:TestTosaDialect",
"//mlir/test:TestTransformDialect",
"//mlir/test:TestTransforms",
"//mlir/test:TestVector",
"//mlir/test:TestVectorToSPIRV",
"//mlir/test:TestXeGPU",
],
)

Expand Down