Description
Describe the issue:
Matrix multiplication can result in an `EOFError` on M1 chipsets. The following example reproduces the issue.
Reproducible code example:
import numpy as np
import pandas as pd
import pymc as pm
from formulae import design_matrices
# 3 columns of 100k random-normal numbers
df = pd.DataFrame(np.random.randn(300000).reshape((-1, 3)), columns=['x1', 'x2', 'y'])
# create design matrix
dm = design_matrices('y ~ 1 + x1 + x2', df)
# fit model with multiplication + summation (this works!)
with pm.Model():
b = pm.Normal('b', mu=0, sigma=1, shape=3)
e = pm.HalfStudentT('e', sigma=1, nu=3)
yhat = (dm.common.design_matrix * b).sum(axis=1)
likelihood = pm.Normal('likelihood', mu=yhat, sigma=e, observed=df.y)
trace = pm.sample()
# fit model with pm.math.dot (this does not work!)
with pm.Model():
b = pm.Normal('b', mu=0, sigma=1, shape=3)
e = pm.HalfStudentT('e', sigma=1, nu=3)
yhat = pm.math.dot(dm.common.design_matrix, b)
likelihood = pm.Normal('likelihood', mu=yhat, sigma=e, observed=df.y)
trace = pm.sample()
Error message:
---------------------------------------------------------------------------
EOFError Traceback (most recent call last)
Cell In[14], line 10
8 df = pd.DataFrame(np.vstack([S, Q, X, Y, P]).T, columns=["S", "Q", "X", "Y", "P"])
9 model = bmb.Model('Y ~ S + Q + X + P', df)
---> 10 results = model.fit()
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/bambi/models.py:325, in Model.fit(self, draws, tune, discard_tuned_samples, omit_offsets, include_mean, inference_method, init, n_init, chains, cores, random_seed, **kwargs)
318 response = self.components[self.response_name]
319 _log.info(
320 "Modeling the probability that %s==%s",
321 response.response_term.name,
322 str(response.response_term.success),
323 )
--> 325 return self.backend.run(
326 draws=draws,
327 tune=tune,
328 discard_tuned_samples=discard_tuned_samples,
329 omit_offsets=omit_offsets,
330 include_mean=include_mean,
331 inference_method=inference_method,
332 init=init,
333 n_init=n_init,
334 chains=chains,
335 cores=cores,
336 random_seed=random_seed,
337 **kwargs,
338 )
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/bambi/backend/pymc.py:96, in PyMCModel.run(self, draws, tune, discard_tuned_samples, omit_offsets, include_mean, inference_method, init, n_init, chains, cores, random_seed, **kwargs)
94 # NOTE: Methods return different types of objects (idata, approximation, and dictionary)
95 if inference_method in ["mcmc", "nuts_numpyro", "nuts_blackjax"]:
---> 96 result = self._run_mcmc(
97 draws,
98 tune,
99 discard_tuned_samples,
100 omit_offsets,
101 include_mean,
102 init,
103 n_init,
104 chains,
105 cores,
106 random_seed,
107 inference_method,
108 **kwargs,
109 )
110 elif inference_method == "vi":
111 result = self._run_vi(**kwargs)
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/bambi/backend/pymc.py:172, in PyMCModel._run_mcmc(self, draws, tune, discard_tuned_samples, omit_offsets, include_mean, init, n_init, chains, cores, random_seed, sampler_backend, **kwargs)
170 if sampler_backend == "mcmc":
171 try:
--> 172 idata = pm.sample(
173 draws=draws,
174 tune=tune,
175 discard_tuned_samples=discard_tuned_samples,
176 init=init,
177 n_init=n_init,
178 chains=chains,
179 cores=cores,
180 random_seed=random_seed,
181 **kwargs,
182 )
183 except (RuntimeError, ValueError):
184 if (
185 "ValueError: Mass matrix contains" in traceback.format_exc()
186 and init == "auto"
187 ):
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/mcmc.py:766, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
764 _print_step_hierarchy(step)
765 try:
--> 766 _mp_sample(**sample_args, **parallel_args)
767 except pickle.PickleError:
768 _log.warning("Could not pickle model, sampling singlethreaded.")
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/mcmc.py:1155, in _mp_sample(draws, tune, step, chains, cores, random_seed, start, progressbar, traces, model, callback, mp_ctx, **kwargs)
1153 try:
1154 with sampler:
-> 1155 for draw in sampler:
1156 strace = traces[draw.chain]
1157 strace.record(draw.point, draw.stats)
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/parallel.py:448, in ParallelSampler.__iter__(self)
445 self._progress.update(self._total_draws)
447 while self._active:
--> 448 draw = ProcessAdapter.recv_draw(self._active)
449 proc, is_last, draw, tuning, stats = draw
450 self._total_draws += 1
File ~/mambaforge/envs/pymc_env/lib/python3.11/site-packages/pymc/sampling/parallel.py:320, in ProcessAdapter.recv_draw(processes, timeout)
318 idxs = {id(proc._msg_pipe): proc for proc in processes}
319 proc = idxs[id(ready[0])]
--> 320 msg = ready[0].recv()
322 if msg[0] == "error":
323 old_error = msg[1]
File ~/mambaforge/envs/pymc_env/lib/python3.11/multiprocessing/connection.py:249, in _ConnectionBase.recv(self)
247 self._check_closed()
248 self._check_readable()
--> 249 buf = self._recv_bytes()
250 return _ForkingPickler.loads(buf.getbuffer())
File ~/mambaforge/envs/pymc_env/lib/python3.11/multiprocessing/connection.py:413, in Connection._recv_bytes(self, maxsize)
412 def _recv_bytes(self, maxsize=None):
--> 413 buf = self._recv(4)
414 size, = struct.unpack("!i", buf.getvalue())
415 if size == -1:
File ~/mambaforge/envs/pymc_env/lib/python3.11/multiprocessing/connection.py:382, in Connection._recv(self, size, read)
380 if n == 0:
381 if remaining == size:
--> 382 raise EOFError
383 else:
384 raise OSError("got end of file during message")
EOFError:
PyTensor version information:
Context for the issue:
Creating a design matrix with `formulae.design_matrices()` isn't the thing that breaks under `mp_ctx='fork'`, but computing the dot product using `pm.math.dot()` does. Computing the dot product using element-wise multiplication and summation works just fine, presumably because that compiles down differently than `pm.math.dot()`.
I can still get the model to sample by reducing the size of the dataset, so it's not that `pm.math.dot()` never runs on M1 chipsets, but it does seem like it's not thread/fork-safe and should be used with the `forkserver` multiprocessing context instead.
Edited to add: The dot product that `pm.math.dot()` has to compute here is dense, so I think it ends up calling `pytensor.tensor.math.dense_dot()`, which I think calls an instance of `pytensor.tensor.math.Dot`. I'm not sure what it is in there that isn't thread/fork-safe; maybe memory allocation for calls to BLAS or something like that?
Originally posted by @jvparidon in bambinos/bambi#700 (comment)