Commit 0480f95

Implementing linear conjugate gradients (#62)

* Initial commit of the linear conjugate gradients
* Updated __init__ so that we can import linear_cg
* Since we are invoking the function as a class attribute, removed the self
* Fixed the import
* Cleaned up some unused code
* Fixed the name of the function
* If a vector is not equal to itself, it might contain NaNs
* Removed a temporary variable
* Moved configuration from the settings class to function arguments
* Ensure the output has the same datatype regardless of the value of n_tridiag
* If n_tridiag = 0, the tridiagonal matrices are not computed; instead an identity matrix is returned, which doesn't affect downstream computation

1 parent 78b7f2f commit 0480f95
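
The last point is worth illustrating: linear_cg always returns a (solution, tridiagonal stack) pair, and with n_tridiag=0 the second element is an identity stand-in, so downstream code never has to branch on the return type. A minimal sketch of the two call modes, assuming a symmetric positive-definite system (the matrix and sizes are illustrative; the exact Lanczos block size depends on when its off-diagonal updates are cut off):

```python
import numpy as np
from pymc_experimental.utils.linear_cg import linear_cg

rng = np.random.default_rng(0)
A = rng.standard_normal((50, 50))
A = A @ A.T + 50 * np.eye(50)  # symmetric positive-definite test matrix
b = rng.standard_normal((50, 1))

x, t_identity = linear_cg(A, b)  # n_tridiag=0: identity stand-in
x2, t_lanczos = linear_cg(A, b, n_tridiag=1)  # actual Lanczos tridiagonalization
print(t_identity.shape, t_lanczos.shape)
```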

File tree: 2 files changed, +277 −0 lines


pymc_experimental/utils/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -1 +1,2 @@
 from pymc_experimental.utils import prior, spline
+from pymc_experimental.utils.linear_cg import linear_cg
```

pymc_experimental/utils/linear_cg.py

Lines changed: 276 additions & 0 deletions

```python
import numpy as np

EVAL_CG_TOLERANCE = 0.01
CG_TOLERANCE = 1


def masked_fill(vector, mask, fill_value):
    # replace the entries of `vector` selected by `mask` with `fill_value`
    masked_vector = np.ma.array(vector, mask=mask)
    vector = masked_vector.filled(fill_value=fill_value)
    return vector
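

# masked_fill powers the "safe division" idiom used throughout this file: entries
# whose denominator is (near-)zero are first set to 1, the division is performed,
# and the affected quotients are then zeroed out, so converged or degenerate
# directions simply stop being updated instead of producing inf/NaN.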


def linear_cg_updates(
    result, alpha, residual_inner_prod, eps, beta, residual, precond_residual, curr_conjugate_vec
):

    # Everything inside _jit_linear_cg_updates
    result = result + alpha * curr_conjugate_vec
    # the old inner product r_k^T z_k becomes the denominator of beta
    beta = np.copy(residual_inner_prod)

    residual_inner_prod = residual.T @ precond_residual

    # safe division
    is_zero = beta < eps
    beta = masked_fill(beta, mask=is_zero, fill_value=1)

    beta = residual_inner_prod / beta
    beta = masked_fill(beta, mask=is_zero, fill_value=0)
    curr_conjugate_vec = beta * curr_conjugate_vec + precond_residual
    return (
        result,
        alpha,
        residual_inner_prod,
        eps,
        beta,
        residual,
        precond_residual,
        curr_conjugate_vec,
    )
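
# In standard preconditioned-CG notation, linear_cg_updates computes
#   x_{k+1} = x_k + alpha_k * p_k
#   beta_k  = (r_{k+1}^T z_{k+1}) / (r_k^T z_k)
#   p_{k+1} = z_{k+1} + beta_k * p_k
# with residual r, preconditioned residual z, and conjugate direction p.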


def linear_cg(
    mat: np.ndarray,
    rhs,
    n_tridiag=0,
    tolerance=None,
    eps=1e-10,
    stop_updating_after=1e-10,
    max_iter=1000,
    max_tridiag_iter=20,
    initial_guess=None,
    preconditioner=None,
    terminate_cg_by_size=False,
    use_eval_tolerance=False,
):

    if initial_guess is None:
        initial_guess = np.zeros_like(rhs)

    if preconditioner is None:
        preconditioner = lambda x: x
        precond = False
    else:
        precond = True

    if tolerance is None:
        if use_eval_tolerance:
            tolerance = EVAL_CG_TOLERANCE
        else:
            tolerance = CG_TOLERANCE

    # If we are running m CG iterations, we obviously can't get more than m Lanczos coefficients
    if max_tridiag_iter > max_iter:
        raise RuntimeError(
            "Getting a tridiagonalization larger than the number of CG iterations run is not possible!"
        )

    is_vector = len(rhs.shape) == 1
    if is_vector:
        rhs = rhs[:, np.newaxis]

    num_rows = rhs.shape[-2]  # number of rows, not the total element count
    n_iter = min(max_iter, num_rows) if terminate_cg_by_size else max_iter
    n_tridiag_iter = min(max_tridiag_iter, num_rows)

    # norm of rhs for convergence tests
    rhs_norm = np.linalg.norm(rhs, 2)
    # make almost-zero norms be 1 (so we don't get divide-by-zero errors)
    rhs_is_zero = rhs_norm < eps
    rhs_norm = masked_fill(rhs_norm, mask=rhs_is_zero, fill_value=1)

    # let's normalize rhs
    rhs = rhs / rhs_norm

    # residuals
    residual = rhs - mat @ initial_guess
    batch_shape = residual.shape[:-2]

    result = np.copy(initial_guess)

    # NaN != NaN, so a vector that is not allclose to itself contains NaNs
    if not np.allclose(residual, residual):
        raise RuntimeError("NaNs encountered when trying to perform matrix-vector multiplication")

    # sometimes we are lucky and the preconditioner solves the system right away
    # check for convergence
    residual_norm = np.linalg.norm(residual, 2)
    has_converged = residual_norm < stop_updating_after

    if has_converged.all() and not n_tridiag:
        n_iter = 0  # skip iterations
    else:
        precond_residual = preconditioner(residual)

        curr_conjugate_vec = precond_residual
        residual_inner_prod = residual.T @ precond_residual

        # define storage matrices
        mul_storage = np.zeros_like(residual)
        alpha = np.zeros((*batch_shape, 1, rhs.shape[-1]))
        beta = np.zeros_like(alpha)
        is_zero = np.zeros((*batch_shape, 1, rhs.shape[-1]))

    # Define tridiagonal matrices, if applicable
    if n_tridiag:
        t_mat = np.zeros((n_tridiag_iter, n_tridiag_iter, *batch_shape, n_tridiag))
        alpha_tridiag_is_zero = np.zeros((*batch_shape, n_tridiag))
        alpha_reciprocal = np.zeros((*batch_shape, n_tridiag))
        prev_alpha_reciprocal = np.zeros_like(alpha_reciprocal)
        prev_beta = np.zeros_like(alpha_reciprocal)

    update_tridiag = True
    last_tridiag_iter = 0

    # it is possible that we don't reach tolerance even after all the iterations are over
    tolerance_reached = False
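
    # Each CG iteration below:
    #   1. forms the matrix-vector product A @ p_k
    #   2. computes the step size alpha_k = (r_k^T z_k) / (p_k^T A p_k) with safe division
    #   3. updates solution, residual, and conjugate direction via linear_cg_updates
    #   4. checks for convergence
    #   5. harvests alpha/beta into the Lanczos tridiagonal matrix while requested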
    # start iteration
    for k in range(n_iter):
        mvms = mat @ curr_conjugate_vec
        if precond:
            # p_k^T A p_k (transposed to match the unpreconditioned branch below)
            alpha = curr_conjugate_vec.T @ mvms

            # safe division
            is_zero = alpha < eps
            alpha = masked_fill(alpha, mask=is_zero, fill_value=1)
            alpha = residual_inner_prod / alpha
            alpha = masked_fill(alpha, mask=is_zero, fill_value=0)

            # cancel out updates by setting directions which have converged to zero
            alpha = masked_fill(alpha, mask=has_converged, fill_value=0)
            residual = residual - alpha * mvms

            # update precond_residual
            precond_residual = preconditioner(residual)

            # Everything inside _jit_linear_cg_updates
            (
                result,
                alpha,
                residual_inner_prod,
                eps,
                beta,
                residual,
                precond_residual,
                curr_conjugate_vec,
            ) = linear_cg_updates(
                result,
                alpha,
                residual_inner_prod,
                eps,
                beta,
                residual,
                precond_residual,
                curr_conjugate_vec,
            )
        else:
            # everything inside _jit_linear_cg_updates_no_precond
            alpha = curr_conjugate_vec.T @ mvms

            # safe division
            is_zero = alpha < eps
            alpha = masked_fill(alpha, mask=is_zero, fill_value=1)
            alpha = residual_inner_prod / alpha
            alpha = masked_fill(alpha, mask=is_zero, fill_value=0)

            # cancel out updates by setting directions which have converged to zero
            alpha = masked_fill(alpha, mask=has_converged, fill_value=0)
            residual = residual - alpha * mvms
            precond_residual = np.copy(residual)

            (
                result,
                alpha,
                residual_inner_prod,
                eps,
                beta,
                residual,
                precond_residual,
                curr_conjugate_vec,
            ) = linear_cg_updates(
                result,
                alpha,
                residual_inner_prod,
                eps,
                beta,
                residual,
                precond_residual,
                curr_conjugate_vec,
            )

        residual_norm = np.linalg.norm(residual, 2)
        residual_norm = masked_fill(residual_norm, mask=rhs_is_zero, fill_value=0)
        has_converged = residual_norm < stop_updating_after

        # require at least min(10, max_iter - 1) iterations, and don't stop early
        # while the tridiagonal matrices are still being filled in
        if (
            k >= min(10, max_iter - 1)
            and bool(residual_norm.mean() < tolerance)
            and not (n_tridiag and k < min(n_tridiag_iter, max_iter - 1))
        ):
            tolerance_reached = True
            break
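
        # The tridiagonal (Lanczos) matrix below is assembled from the CG scalars via
        # the standard CG-to-Lanczos correspondence:
        #   T[k, k]               = 1/alpha_k + beta_{k-1}/alpha_{k-1}
        #   T[k, k-1] = T[k-1, k] = sqrt(beta_{k-1}) / alpha_{k-1}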

        # Update tridiagonal matrices, if applicable
        if n_tridiag and k < n_tridiag_iter and update_tridiag:
            alpha_tridiag = np.copy(alpha)
            beta_tridiag = np.copy(beta)

            alpha_tridiag_is_zero = alpha_tridiag == 0
            alpha_tridiag = masked_fill(alpha_tridiag, mask=alpha_tridiag_is_zero, fill_value=1)
            alpha_reciprocal = 1 / alpha_tridiag
            alpha_tridiag = masked_fill(alpha_tridiag, mask=alpha_tridiag_is_zero, fill_value=0)

            if k == 0:
                t_mat[k, k] = alpha_reciprocal
            else:
                t_mat[k, k] += np.squeeze(alpha_reciprocal + prev_beta * prev_alpha_reciprocal)
                t_mat[k, k - 1] = np.sqrt(prev_beta) * prev_alpha_reciprocal
                t_mat[k - 1, k] = np.copy(t_mat[k, k - 1])

                if t_mat[k - 1, k].max() < 1e-6:
                    update_tridiag = False

            last_tridiag_iter = k

            prev_alpha_reciprocal = np.copy(alpha_reciprocal)
            prev_beta = np.copy(beta_tridiag)

    # un-normalize
    result = result * rhs_norm
    if not tolerance_reached and n_iter > 0:
        raise RuntimeError(
            "CG terminated in {} iterations with average residual norm {},"
            " which is larger than the specified tolerance of {}."
            " If performance is affected, consider raising the maximum number of CG"
            " iterations (the max_iter argument).".format(k + 1, residual_norm.mean(), tolerance)
        )

    if n_tridiag:
        t_mat = t_mat[: last_tridiag_iter + 1, : last_tridiag_iter + 1]
        return result, t_mat.transpose(-1, *range(2, 2 + len(batch_shape)), 0, 1)
    else:
        # We set the estimated Lanczos tridiagonal matrices to the identity so that the
        # subsequent eigendecomposition (eq. S7 of https://arxiv.org/pdf/1809.11165.pdf)
        # works fine. t_mat above has shape
        # (n_tridiag_iter, n_tridiag_iter, *batch_shape, n_tridiag), and after the
        # transpose its last two dimensions are dimensions 0 and 1 of that array, i.e.
        # both of size n_tridiag_iter. So we generate identity matrices of size
        # n_tridiag_iter and tile them over the n_tridiag and batch dimensions.
        # TODO: for the same input, n_tridiag = True and n_tridiag = False must produce
        # t_mat with the same shape (with an assumed n_tridiag = 1)
        n_tridiag = 1
        eye = np.eye(n_tridiag_iter)
        t_mat_eye = np.tile(eye, [n_tridiag] + [1] * (len(batch_shape) + 2))
        return result, t_mat_eye
```
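
A quick end-to-end check one might run against a direct solve. This is an illustrative sketch, assuming a well-conditioned symmetric positive-definite matrix; the sizes, seed, and tolerance are arbitrary and not taken from the commit's tests:

```python
import numpy as np
from pymc_experimental.utils.linear_cg import linear_cg

rng = np.random.default_rng(42)
A = rng.standard_normal((100, 100))
A = A @ A.T + 100 * np.eye(100)  # well-conditioned SPD matrix
b = rng.standard_normal((100, 1))  # column-vector right-hand side

x_cg, _ = linear_cg(A, b, tolerance=0.01)
x_direct = np.linalg.solve(A, b)
print(np.max(np.abs(x_cg - x_direct)))  # should be small
```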
