Skip to content

Commit d261c7d

Browse files
authored
Merge pull request #16 from soblin/develop
Develop
2 parents 7e1ab08 + 0bfe4d4 commit d261c7d

File tree

10 files changed

+169
-23
lines changed

10 files changed

+169
-23
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ find_package(pybind11 2.4.3 REQUIRED)
1414
# check matplotlib minor version
1515
execute_process(
1616
COMMAND
17-
"python3" "-c"
17+
${Python3_EXECUTABLE} "-c"
1818
"import matplotlib;
1919
print(str(matplotlib.__version__))"
2020
RESULT_VARIABLE MATPLOTLIB_VERSION_CHECKING

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ It is supposed to provide the user with almost full access to matplotlib feature
1313
- [pybind11](https://github.com/pybind/pybind11) >= 2.4.3
1414
- `sudo apt install pybind11-dev` (on Ubuntu20.04)
1515
- or manual install
16-
- compatible with [matplotlib](https://matplotlib.org/stable/index.html) == 3.5.1
16+
- [matplotlib](https://matplotlib.org/stable/index.html) >= 3.4.0
1717
- numpy for `mplot3d`
1818
- ([xtensor](https://github.com/xtensor-stack/xtensor) == 0.24.0 + [xtl](https://github.com/xtensor-stack/xtl), only for `gallery` demos)
1919

gallery/shapes_and_collections/patch_collection.cpp

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,46 +3,50 @@
33

44
#include <pybind11/embed.h>
55
#include <pybind11/stl.h>
6+
#include <pybind11/numpy.h>
67

78
#include <matplotlibcpp17/pyplot.h>
89
#include <matplotlibcpp17/patches.h>
910

10-
#include <algorithm>
11-
#include <iostream>
11+
#include <xtensor/xrandom.hpp>
12+
1213
#include <vector>
1314

1415
namespace py = pybind11;
1516
using namespace py::literals;
1617
using namespace std;
17-
using namespace matplotlibcpp17;
18+
using namespace matplotlibcpp17::patches;
1819
using namespace matplotlibcpp17;
1920

2021
int main() {
2122
py::scoped_interpreter guard{};
2223
auto plt = matplotlibcpp17::pyplot::import();
2324
auto [fig, ax] = plt.subplots();
2425

26+
const int resolution = 50;
2527
const int N = 3;
26-
vector<double> x = {0.7003673, 0.74275081, 0.70928001},
27-
y = {0.56674552, 0.97778533, 0.70633485},
28-
radii = {0.02479158, 0.01578834, 0.06976985};
28+
xt::xarray<double> x = xt::random::rand<double>({N});
29+
xt::xarray<double> y = xt::random::rand<double>({N});
30+
xt::xarray<double> radii = 0.1 * xt::random::rand<double>({N});
2931
py::list patches; // instead of patches = []
3032
for (int i = 0; i < N; ++i) {
3133
const double x1 = x[i], y1 = y[i], r = radii[i];
3234
auto circle = patches::Circle(Args(py::make_tuple(x1, y1), r));
3335
patches.append(circle.unwrap());
3436
}
35-
x = {0.71995667, 0.25774443, 0.34154678};
36-
y = {0.96876117, 0.6945071, 0.46638326};
37-
radii = {0.07028127, 0.05117859, 0.09287414};
38-
vector<double> theta1 = {266.3169476, 224.07805212, 234.5563688},
39-
theta2 = {142.85074015, 195.56618216, 287.96383014};
37+
38+
x = xt::random::rand<double>({N});
39+
y = xt::random::rand<double>({N});
40+
radii = 0.1 * xt::random::rand<double>({N});
41+
xt::xarray<double> theta1 = 360.0 * xt::random::rand<double>({N});
42+
xt::xarray<double> theta2 = 360.0 * xt::random::rand<double>({N});
4043
for (int i = 0; i < N; ++i) {
4144
const double x1 = x[i], y1 = y[i], r = radii[i], th1 = theta1[i],
4245
th2 = theta2[i];
4346
auto wedge = patches::Wedge(Args(py::make_tuple(x1, y1), r, th1, th2));
4447
patches.append(wedge.unwrap());
4548
}
49+
4650
patches.append(
4751
patches::Wedge(Args(py::make_tuple(0.3, 0.7), 0.1, 0, 360)).unwrap());
4852
patches.append(patches::Wedge(Args(py::make_tuple(0.7, 0.8), 0.2, 0, 360),
@@ -53,16 +57,30 @@ int main() {
5357
patches.append(patches::Wedge(Args(py::make_tuple(0.8, 0.3), 0.2, 45, 90),
5458
Kwargs("width"_a = 0.10))
5559
.unwrap());
56-
// NOTE: Polygon take numpy array as argument, so skip it
57-
vector<double> colors_ = {90.63036451, 16.10182093, 74.36211347, 63.29741618,
58-
32.41800177, 92.23765324, 23.72264387, 82.39455709,
59-
75.06071403, 11.37844527};
60-
py::list colors = py::cast(colors_);
60+
61+
for (int i = 0; i < N; ++i) {
62+
auto poly__ = xt::random::rand<double>({N, 2});
63+
vector<vector<double>> poly_(N);
64+
// to vector<vector>>
65+
for (int j = 0; j < N; ++j) {
66+
poly_[j].resize(2);
67+
poly_[j][0] = poly__(j, 0);
68+
poly_[j][1] = poly__(j, 1);
69+
}
70+
// to numpy array
71+
auto poly = py::array(py::cast(std::move(poly_)));
72+
auto polygon = Polygon(Args(poly, true));
73+
patches.append(polygon.unwrap());
74+
}
75+
76+
auto colors__ = 100.0 * xt::random::rand<double>({patches.size()});
77+
vector<double> colors_(colors__.begin(), colors__.end());
78+
py::array colors = py::cast(colors_);
6179
auto p = collections::PatchCollection(Args(patches), Kwargs("alpha"_a = 0.4));
6280
p.set_array(Args(colors));
63-
// NOTE: error in python3.6.9 ?
6481
ax.add_collection(Args(p.unwrap()));
6582
fig.colorbar(Args(p.unwrap()), Kwargs("ax"_a = ax.unwrap()));
83+
6684
#if USE_GUI
6785
plt.show();
6886
#else
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
add_demo(align_labels_demo align_labels_demo.cpp)
22
add_demo(gridspec_multicolumn gridspec_multicolumn.cpp)
33
add_demo(multiple_figs_demo multiple_figs_demo.cpp)
4+
add_demo(colorbar_placement colorbar_placement.cpp)
5+
add_demo(subplots subplots.cpp)
46

57
add_custom_target(subplots_axes_and_figures
6-
DEPENDS align_labels_demo gridspec_multicolumn multiple_figs_demo
8+
DEPENDS align_labels_demo gridspec_multicolumn multiple_figs_demo colorbar_placement
79
COMMAND align_labels_demo
810
COMMAND gridspec_multicolumn
911
COMMAND multiple_figs_demo
12+
COMMAND colorbar_placement
1013
COMMENT "subplots_axes_and_figures"
1114
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
1215
)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// example from
2+
// https://matplotlib.org/stable/gallery/subplots_axes_and_figures/colorbar_placement.html
3+
4+
#include <pybind11/embed.h>
5+
#include <pybind11/stl.h>
6+
7+
#include <matplotlibcpp17/pyplot.h>
8+
9+
#include <xtensor/xrandom.hpp>
10+
11+
#include <vector>
12+
13+
namespace py = pybind11;
14+
using namespace py::literals;
15+
using namespace std;
16+
using namespace matplotlibcpp17;
17+
18+
int main1() {
19+
auto plt = matplotlibcpp17::pyplot::import();
20+
auto [fig, axs] = plt.subplots(2, 2);
21+
const vector<string> cmaps = {"RdBu_r", "viridis"};
22+
for (auto col : {0, 1}) {
23+
for (auto row : {0, 1}) {
24+
auto x_ = xt::random::randn<double>({20, 20}) * (col + 1.0);
25+
vector<vector<double>> x(20);
26+
for (int i = 0; i < 20; ++i) {
27+
x[i].resize(20);
28+
for (int j = 0; j < 20; ++j)
29+
x[i][j] = x_(i, j);
30+
}
31+
auto &ax = axs[col + row * 2];
32+
auto pcm = ax.pcolormesh(Args(x), Kwargs("cmap"_a = cmaps[col]));
33+
fig.colorbar(Args(pcm.unwrap()),
34+
Kwargs("ax"_a = ax.unwrap(), "shrink"_a = 0.6));
35+
}
36+
}
37+
#if USE_GUI
38+
plt.show();
39+
#else
40+
plt.savefig(Args("colorbar_placement1.png"));
41+
#endif
42+
return 0;
43+
}
44+
45+
int main() {
46+
py::scoped_interpreter guard{};
47+
main1();
48+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include <pybind11/embed.h>
2+
#include <pybind11/stl.h>
3+
4+
#include <matplotlibcpp17/pyplot.h>
5+
6+
namespace py = pybind11;
7+
using namespace py::literals;
8+
using namespace std;
9+
using namespace matplotlibcpp17;
10+
11+
int main() {
12+
py::scoped_interpreter guard{};
13+
auto plt = matplotlibcpp17::pyplot::import();
14+
{
15+
auto [fig, axs] = plt.subplots(3, 1);
16+
std::cout << axs.size() << std::endl;
17+
}
18+
{
19+
auto [fig, axs] = plt.subplots(1, 1);
20+
std::cout << axs.size() << std::endl;
21+
}
22+
{
23+
auto [fig, axs] = plt.subplots(3, 3);
24+
std::cout << axs.size() << std::endl;
25+
}
26+
return 0;
27+
}

include/matplotlibcpp17/axes.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,11 @@ struct DECL_STRUCT_ATTR Axes : public BaseWrapper {
123123
legend::Legend legend(const pybind11::tuple &args = pybind11::tuple(),
124124
const pybind11::dict &kwargs = pybind11::dict());
125125

126+
// pcolormesh
127+
collections::QuadMesh
128+
pcolormesh(const pybind11::tuple &args = pybind11::tuple(),
129+
const pybind11::dict &kwargs = pybind11::dict());
130+
126131
// plot
127132
pybind11::object plot(const pybind11::tuple &args = pybind11::tuple(),
128133
const pybind11::dict &kwargs = pybind11::dict());
@@ -230,6 +235,7 @@ struct DECL_STRUCT_ATTR Axes : public BaseWrapper {
230235
LOAD_FUNC_ATTR(hist2d, self);
231236
LOAD_FUNC_ATTR(invert_yaxis, self);
232237
LOAD_FUNC_ATTR(legend, self);
238+
LOAD_FUNC_ATTR(pcolormesh, self);
233239
LOAD_FUNC_ATTR(plot, self);
234240
// NOTE: only when called with projection='3d', `plot_surface`, `plot_wireframe`, `set_zlabel` prop exists.
235241
try {
@@ -278,6 +284,7 @@ struct DECL_STRUCT_ATTR Axes : public BaseWrapper {
278284
pybind11::object hist2d_attr;
279285
pybind11::object invert_yaxis_attr;
280286
pybind11::object legend_attr;
287+
pybind11::object pcolormesh_attr;
281288
pybind11::object plot_attr;
282289
pybind11::object plot_surface_attr;
283290
pybind11::object plot_wireframe_attr;
@@ -478,6 +485,13 @@ legend::Legend Axes::legend(const pybind11::tuple &args,
478485
return legend::Legend(obj);
479486
}
480487

488+
// pcolormesh
489+
collections::QuadMesh Axes::pcolormesh(const pybind11::tuple &args,
490+
const pybind11::dict &kwargs) {
491+
pybind11::object ret = pcolormesh_attr(*args, **kwargs);
492+
return collections::QuadMesh(ret);
493+
}
494+
481495
// plot
482496
pybind11::object Axes::plot(const pybind11::tuple &args,
483497
const pybind11::dict &kwargs) {

include/matplotlibcpp17/collections.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,15 @@ pybind11::object PatchCollection::set_array(const pybind11::tuple &args,
7373
pybind11::object ret = set_array_attr(*args, **kwargs);
7474
return ret;
7575
}
76+
77+
/**
78+
* @brief A wrapper class for matplotlib.collections.QuadMesh
79+
**/
80+
struct DECL_STRUCT_ATTR QuadMesh : public BaseWrapper {
81+
public:
82+
QuadMesh(pybind11::object quadmesh) { self = quadmesh; }
83+
};
84+
7685
} // namespace matplotlibcpp17::collections
7786

7887
#endif /* MATPLOTLIBCPP17_COLLECTIONS_H */

include/matplotlibcpp17/patches.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,22 @@ struct DECL_STRUCT_ATTR Wedge : public BaseWrapper {
7474
pybind11::object wedge_attr;
7575
};
7676

77+
/**
78+
* @brief A wrapper class for matplotlib.patches.Polygon
79+
**/
80+
struct DECL_STRUCT_ATTR Polygon : public BaseWrapper {
81+
public:
82+
Polygon(const pybind11::tuple &args = pybind11::tuple(),
83+
const pybind11::dict &kwargs = pybind11::dict()) {
84+
polygon_attr =
85+
pybind11::module::import("matplotlib.patches").attr("Polygon");
86+
self = polygon_attr(*args, **kwargs);
87+
}
88+
89+
private:
90+
pybind11::object polygon_attr;
91+
};
92+
7793
} // namespace matplotlibcpp17::patches
7894

7995
#endif /* MATPLOTLIBCPP17_PATCHES_H */

include/matplotlibcpp17/pyplot.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,18 +300,29 @@ PyPlot::subplots(const pybind11::dict &kwargs) {
300300

301301
std::tuple<figure::Figure, std::vector<axes::Axes>>
302302
PyPlot::subplots(int r, int c, const pybind11::dict &kwargs) {
303+
// subplots() returns [][] (if r > 1 && c > 1) else []
304+
// return []axes in row-major
305+
// NOTE: equal to Axes.flat
303306
pybind11::tuple args = pybind11::make_tuple(r, c);
304307
pybind11::list ret = subplots_attr(*args, **kwargs);
305308
std::vector<axes::Axes> axes;
306309
pybind11::object fig = ret[0];
307310
figure::Figure figure(fig);
308311
if (r == 1 and c == 1) {
309-
pybind11::object ax = ret[1];
310-
axes.push_back(axes::Axes(ax));
312+
// python returns Axes
313+
axes.push_back(axes::Axes(ret[1]));
314+
} else if (r == 1 or c == 1) {
315+
// python returns []Axes
316+
pybind11::list axs = ret[1];
317+
for (int i = 0; i < r * c; ++i)
318+
axes.push_back(axes::Axes(axs[i]));
311319
} else {
320+
// python returns [][]Axes
312321
pybind11::list axs = ret[1];
313322
for (pybind11::size_t i = 0; i < axs.size(); ++i) {
314-
axes.push_back(axes::Axes(axs[i]));
323+
pybind11::list axsi = axs[i];
324+
for (unsigned j = 0; j < axsi.size(); ++j)
325+
axes.push_back(axes::Axes(axsi[j]));
315326
}
316327
}
317328
return {figure, axes};

0 commit comments

Comments
 (0)