Skip to content

Commit ac0a70f

Browse files
committed
[mlir] Split out Python bindings entry point into a separate file
This will allow the bindings to be built as a library and reused in out-of-tree projects that want to provide bindings on top of MLIR bindings. Reviewed By: stellaraccident, mikeurbach Differential Revision: https://reviews.llvm.org/D101075
1 parent 54ee962 commit ac0a70f

File tree

3 files changed

+147
-129
lines changed

3 files changed

+147
-129
lines changed

mlir/lib/Bindings/Python/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir
8484
IRAffine.cpp
8585
IRAttributes.cpp
8686
IRCore.cpp
87+
IRModule.cpp
8788
IRTypes.cpp
8889
PybindUtils.cpp
8990
Pass.cpp

mlir/lib/Bindings/Python/IRModule.cpp

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
//===- IRModule.cpp - IR pybind module ------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "IRModule.h"
10+
#include "Globals.h"
11+
#include "PybindUtils.h"
12+
13+
#include <vector>
14+
15+
namespace py = pybind11;
16+
using namespace mlir;
17+
using namespace mlir::python;
18+
19+
// -----------------------------------------------------------------------------
20+
// PyGlobals
21+
// -----------------------------------------------------------------------------
22+
23+
PyGlobals *PyGlobals::instance = nullptr;
24+
25+
PyGlobals::PyGlobals() {
26+
assert(!instance && "PyGlobals already constructed");
27+
instance = this;
28+
}
29+
30+
PyGlobals::~PyGlobals() { instance = nullptr; }
31+
32+
void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
33+
py::gil_scoped_acquire();
34+
if (loadedDialectModulesCache.contains(dialectNamespace))
35+
return;
36+
// Since re-entrancy is possible, make a copy of the search prefixes.
37+
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
38+
py::object loaded;
39+
for (std::string moduleName : localSearchPrefixes) {
40+
moduleName.push_back('.');
41+
moduleName.append(dialectNamespace.data(), dialectNamespace.size());
42+
43+
try {
44+
py::gil_scoped_release();
45+
loaded = py::module::import(moduleName.c_str());
46+
} catch (py::error_already_set &e) {
47+
if (e.matches(PyExc_ModuleNotFoundError)) {
48+
continue;
49+
} else {
50+
throw;
51+
}
52+
}
53+
break;
54+
}
55+
56+
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
57+
// may have occurred, which may do anything.
58+
loadedDialectModulesCache.insert(dialectNamespace);
59+
}
60+
61+
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
62+
py::object pyClass) {
63+
py::gil_scoped_acquire();
64+
py::object &found = dialectClassMap[dialectNamespace];
65+
if (found) {
66+
throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
67+
dialectNamespace +
68+
"' is already registered.");
69+
}
70+
found = std::move(pyClass);
71+
}
72+
73+
void PyGlobals::registerOperationImpl(const std::string &operationName,
74+
py::object pyClass,
75+
py::object rawOpViewClass) {
76+
py::gil_scoped_acquire();
77+
py::object &found = operationClassMap[operationName];
78+
if (found) {
79+
throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
80+
operationName +
81+
"' is already registered.");
82+
}
83+
found = std::move(pyClass);
84+
rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
85+
}
86+
87+
llvm::Optional<py::object>
88+
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
89+
py::gil_scoped_acquire();
90+
loadDialectModule(dialectNamespace);
91+
// Fast match against the class map first (common case).
92+
const auto foundIt = dialectClassMap.find(dialectNamespace);
93+
if (foundIt != dialectClassMap.end()) {
94+
if (foundIt->second.is_none())
95+
return llvm::None;
96+
assert(foundIt->second && "py::object is defined");
97+
return foundIt->second;
98+
}
99+
100+
// Not found and loading did not yield a registration. Negative cache.
101+
dialectClassMap[dialectNamespace] = py::none();
102+
return llvm::None;
103+
}
104+
105+
llvm::Optional<pybind11::object>
106+
PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
107+
{
108+
py::gil_scoped_acquire();
109+
auto foundIt = rawOpViewClassMapCache.find(operationName);
110+
if (foundIt != rawOpViewClassMapCache.end()) {
111+
if (foundIt->second.is_none())
112+
return llvm::None;
113+
assert(foundIt->second && "py::object is defined");
114+
return foundIt->second;
115+
}
116+
}
117+
118+
// Not found. Load the dialect namespace.
119+
auto split = operationName.split('.');
120+
llvm::StringRef dialectNamespace = split.first;
121+
loadDialectModule(dialectNamespace);
122+
123+
// Attempt to find from the canonical map and cache.
124+
{
125+
py::gil_scoped_acquire();
126+
auto foundIt = rawOpViewClassMap.find(operationName);
127+
if (foundIt != rawOpViewClassMap.end()) {
128+
if (foundIt->second.is_none())
129+
return llvm::None;
130+
assert(foundIt->second && "py::object is defined");
131+
// Positive cache.
132+
rawOpViewClassMapCache[operationName] = foundIt->second;
133+
return foundIt->second;
134+
} else {
135+
// Negative cache.
136+
rawOpViewClassMap[operationName] = py::none();
137+
return llvm::None;
138+
}
139+
}
140+
}
141+
142+
void PyGlobals::clearImportCache() {
143+
py::gil_scoped_acquire();
144+
loadedDialectModulesCache.clear();
145+
rawOpViewClassMapCache.clear();
146+
}

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 0 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -20,135 +20,6 @@ namespace py = pybind11;
2020
using namespace mlir;
2121
using namespace mlir::python;
2222

23-
// -----------------------------------------------------------------------------
24-
// PyGlobals
25-
// -----------------------------------------------------------------------------
26-
27-
PyGlobals *PyGlobals::instance = nullptr;
28-
29-
PyGlobals::PyGlobals() {
30-
assert(!instance && "PyGlobals already constructed");
31-
instance = this;
32-
}
33-
34-
PyGlobals::~PyGlobals() { instance = nullptr; }
35-
36-
void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
37-
py::gil_scoped_acquire();
38-
if (loadedDialectModulesCache.contains(dialectNamespace))
39-
return;
40-
// Since re-entrancy is possible, make a copy of the search prefixes.
41-
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
42-
py::object loaded;
43-
for (std::string moduleName : localSearchPrefixes) {
44-
moduleName.push_back('.');
45-
moduleName.append(dialectNamespace.data(), dialectNamespace.size());
46-
47-
try {
48-
py::gil_scoped_release();
49-
loaded = py::module::import(moduleName.c_str());
50-
} catch (py::error_already_set &e) {
51-
if (e.matches(PyExc_ModuleNotFoundError)) {
52-
continue;
53-
} else {
54-
throw;
55-
}
56-
}
57-
break;
58-
}
59-
60-
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
61-
// may have occurred, which may do anything.
62-
loadedDialectModulesCache.insert(dialectNamespace);
63-
}
64-
65-
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
66-
py::object pyClass) {
67-
py::gil_scoped_acquire();
68-
py::object &found = dialectClassMap[dialectNamespace];
69-
if (found) {
70-
throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
71-
dialectNamespace +
72-
"' is already registered.");
73-
}
74-
found = std::move(pyClass);
75-
}
76-
77-
void PyGlobals::registerOperationImpl(const std::string &operationName,
78-
py::object pyClass,
79-
py::object rawOpViewClass) {
80-
py::gil_scoped_acquire();
81-
py::object &found = operationClassMap[operationName];
82-
if (found) {
83-
throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
84-
operationName +
85-
"' is already registered.");
86-
}
87-
found = std::move(pyClass);
88-
rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
89-
}
90-
91-
llvm::Optional<py::object>
92-
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
93-
py::gil_scoped_acquire();
94-
loadDialectModule(dialectNamespace);
95-
// Fast match against the class map first (common case).
96-
const auto foundIt = dialectClassMap.find(dialectNamespace);
97-
if (foundIt != dialectClassMap.end()) {
98-
if (foundIt->second.is_none())
99-
return llvm::None;
100-
assert(foundIt->second && "py::object is defined");
101-
return foundIt->second;
102-
}
103-
104-
// Not found and loading did not yield a registration. Negative cache.
105-
dialectClassMap[dialectNamespace] = py::none();
106-
return llvm::None;
107-
}
108-
109-
llvm::Optional<pybind11::object>
110-
PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
111-
{
112-
py::gil_scoped_acquire();
113-
auto foundIt = rawOpViewClassMapCache.find(operationName);
114-
if (foundIt != rawOpViewClassMapCache.end()) {
115-
if (foundIt->second.is_none())
116-
return llvm::None;
117-
assert(foundIt->second && "py::object is defined");
118-
return foundIt->second;
119-
}
120-
}
121-
122-
// Not found. Load the dialect namespace.
123-
auto split = operationName.split('.');
124-
llvm::StringRef dialectNamespace = split.first;
125-
loadDialectModule(dialectNamespace);
126-
127-
// Attempt to find from the canonical map and cache.
128-
{
129-
py::gil_scoped_acquire();
130-
auto foundIt = rawOpViewClassMap.find(operationName);
131-
if (foundIt != rawOpViewClassMap.end()) {
132-
if (foundIt->second.is_none())
133-
return llvm::None;
134-
assert(foundIt->second && "py::object is defined");
135-
// Positive cache.
136-
rawOpViewClassMapCache[operationName] = foundIt->second;
137-
return foundIt->second;
138-
} else {
139-
// Negative cache.
140-
rawOpViewClassMap[operationName] = py::none();
141-
return llvm::None;
142-
}
143-
}
144-
}
145-
146-
void PyGlobals::clearImportCache() {
147-
py::gil_scoped_acquire();
148-
loadedDialectModulesCache.clear();
149-
rawOpViewClassMapCache.clear();
150-
}
151-
15223
// -----------------------------------------------------------------------------
15324
// Module initialization.
15425
// -----------------------------------------------------------------------------

0 commit comments

Comments
 (0)