Skip to content

Commit 629da39

Browse files
committed
improve matrix class
1 parent 59d56c9 commit 629da39

File tree

1 file changed

+38
-11
lines changed

1 file changed

+38
-11
lines changed

cp-algo/linalg/matrix.hpp

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,44 @@ namespace cp_algo::linalg {
2626

2727
matrix(Base const& t): Base(t) {}
2828
matrix(Base &&t): Base(std::move(t)) {}
29-
30-
static matrix from(auto &&r) {
31-
return std::ranges::to<Base>(r);
32-
}
29+
30+
template<std::ranges::input_range R>
31+
matrix(R &&r): Base(std::ranges::to<Base>(std::forward<R>(r))) {}
3332

3433
size_t n() const {return size(*this);}
3534
size_t m() const {return n() ? size(row(0)) : 0;}
36-
auto dim() const {return std::array{n(), m()};}
35+
36+
void resize(size_t n, size_t m) {
37+
Base::resize(n);
38+
for(auto &it: *this) {
39+
it.resize(m);
40+
}
41+
}
3742

3843
auto& row(size_t i) {return (*this)[i];}
3944
auto const& row(size_t i) const {return (*this)[i];}
4045

46+
auto elements() {return *this | std::views::join;}
47+
auto elements() const {return *this | std::views::join;}
4148

42-
auto operator-() const {
43-
return from(*this | std::views::transform([](auto x) {return vec_t(-x);}));
49+
matrix operator-() const {
50+
return *this | std::views::transform([](auto x) {return vec_t(-x);});
4451
}
52+
matrix& operator+=(matrix const& t) {
53+
for(auto [a, b]: std::views::zip(elements(), t.elements())) {
54+
a += b;
55+
}
56+
return *this;
57+
}
58+
matrix& operator -=(matrix const& t) {
59+
for(auto [a, b]: std::views::zip(elements(), t.elements())) {
60+
a -= b;
61+
}
62+
return *this;
63+
}
64+
matrix operator+(matrix const& t) const {return matrix(*this) += t;}
65+
matrix operator-(matrix const& t) const {return matrix(*this) -= t;}
66+
4567
matrix& operator *=(base t) {for(auto &it: *this) it *= t; return *this;}
4668
matrix operator *(base t) const {return matrix(*this) *= t;}
4769
matrix& operator /=(base t) {return *this *= base(1) / t;}
@@ -109,6 +131,11 @@ namespace cp_algo::linalg {
109131
}
110132
return res;
111133
}
134+
void assign_submatrix(auto viewx, auto viewy, matrix const& t) {
135+
for(auto [a, b]: std::views::zip(*this | viewx, t)) {
136+
std::ranges::copy(b, begin(a | viewy));
137+
}
138+
}
112139
auto submatrix(auto viewx, auto viewy) const {
113140
return *this | viewx | std::views::transform([viewy](auto const& y) {
114141
return y | viewy;
@@ -214,7 +241,7 @@ namespace cp_algo::linalg {
214241
det *= b[i][i];
215242
b[i] *= base(1) / b[i][i];
216243
}
217-
return {det, from(b.submatrix(std::views::all, std::views::drop(n())))};
244+
return {det, b.submatrix(std::views::all, std::views::drop(n()))};
218245
}
219246

220247
// Can also just run gauss on T() | eye(m)
@@ -238,15 +265,15 @@ namespace cp_algo::linalg {
238265
// [solution, basis], transposed
239266
std::optional<std::array<matrix, 2>> solve(matrix t) const {
240267
matrix sols = (*this | t).kernel();
241-
if(sols.n() < t.m() || from(sols.submatrix(
268+
if(sols.n() < t.m() || matrix(sols.submatrix(
242269
std::views::drop(sols.n() - t.m()),
243270
std::views::drop(m())
244271
)) != -eye(t.m())) {
245272
return std::nullopt;
246273
} else {
247274
return std::array{
248-
from(sols.submatrix(std::views::drop(sols.n() - t.m()), std::views::take(m()))),
249-
from(sols.submatrix(std::views::take(sols.n() - t.m()), std::views::take(m())))
275+
matrix(sols.submatrix(std::views::drop(sols.n() - t.m()), std::views::take(m()))),
276+
matrix(sols.submatrix(std::views::take(sols.n() - t.m()), std::views::take(m())))
250277
};
251278
}
252279
}

0 commit comments

Comments
 (0)