Commit 0f8d876

Make BLAS flags check lazy and more actionable
Replaces the old warning, which did not actually apply, with a more informative and actionable one. The old warning concerned Ops that might use the alternative blas_headers, which rely on the NumPy C-API. Regular PyTensor users have not hit this path in a while: the only Op whose C code uses these alternative headers is the GEMM Op, which is not produced by the current rewrites. The rewrites introduce Dot22 or Dot22Scalar instead, and those refuse to generate C code altogether when the BLAS flags are missing.
1 parent 5fb56ba commit 0f8d876

3 files changed: +101 -35 lines

doc/troubleshooting.rst (65 additions, 26 deletions)

```diff
@@ -145,44 +145,64 @@ How do I configure/test my BLAS library
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 There are many ways to configure BLAS for PyTensor. This is done with the PyTensor
-flags ``blas__ldflags`` (:ref:`libdoc_config`). The default is to use the BLAS
-installation information in NumPy, accessible via
-``numpy.__config__.show()``. You can tell pytensor to use a different
-version of BLAS, in case you did not compile NumPy with a fast BLAS or if NumPy
-was compiled with a static library of BLAS (the latter is not supported in
-PyTensor).
+flags ``blas__ldflags`` (:ref:`libdoc_config`). If not specified, PyTensor will
+attempt to find a local BLAS library to link against, prioritizing specialized implementations.
+The details can be found in :func:`pytensor.link.c.cmodule.default_blas_ldflags`.
 
-The short way to configure the PyTensor flags ``blas__ldflags`` is by setting the
-environment variable :envvar:`PYTENSOR_FLAGS` to ``blas__ldflags=XXX`` (in bash
-``export PYTENSOR_FLAGS=blas__ldflags=XXX``)
+Users can manually set the PyTensor flags ``blas__ldflags`` to link against a
+specific version. This is useful even if the default version is the desired one,
+as it avoids the costly work of searching for the best BLAS library at runtime.
 
-The ``${HOME}/.pytensorrc`` file is the simplest way to set a relatively
-permanent option like this one. Add a ``[blas]`` section with an ``ldflags``
-entry like this:
+The PyTensor flags can be set in a few ways:
+
+1. In the ``${HOME}/.pytensorrc`` file:
 
 .. code-block:: cfg
 
     # other stuff can go here
     [blas]
-    ldflags = -lf77blas -latlas -lgfortran #put your flags here
+    ldflags = -llapack -lblas -lcblas  # put your flags here
 
     # other stuff can go here
 
-For more information on the formatting of ``~/.pytensorrc`` and the
-configuration options that you can put there, see :ref:`libdoc_config`.
+2. In bash, before running your script:
+
+.. code-block:: bash
+
+    export PYTENSOR_FLAGS="blas__ldflags='-llapack -lblas -lcblas'"
+
+3. In an IPython/Jupyter notebook, before importing PyTensor:
+
+.. code-block:: python
+
+    %set_env PYTENSOR_FLAGS=blas__ldflags='-llapack -lblas -lcblas'
+
+4. In ``pytensor.config`` directly:
+
+.. code-block:: python
+
+    import pytensor
+    pytensor.config.blas__ldflags = '-llapack -lblas -lcblas'
+
+(For more information on the formatting of ``~/.pytensorrc`` and the
+configuration options that you can put there, see :ref:`libdoc_config`.)
+
+You can find the default BLAS library that PyTensor is linking against by
+checking ``pytensor.config.blas__ldflags``
+or by running :func:`pytensor.link.c.cmodule.default_blas_ldflags`.
 
 Here are some different ways to configure BLAS:
 
-0) Do nothing and use the default config, which is to link against the same
-BLAS against which NumPy was built. This does not work in the case NumPy was
-compiled with a static library (e.g. ATLAS is compiled by default only as a
-static library).
+0) Do nothing and use the default config.
+This will usually work great for installations via conda/mamba/pixi (conda-forge channel).
+It will usually fail to link altogether for installations via pip.
 
 1) Disable the usage of BLAS and fall back on NumPy for dot products. To do
-this, set the value of ``blas__ldflags`` as the empty string (ex: ``export
-PYTENSOR_FLAGS=blas__ldflags=``). Depending on the kind of matrix operations your
-PyTensor code performs, this might slow some things down (vs. linking with BLAS
-directly).
+this, set the value of ``blas__ldflags`` to the empty string.
+Depending on the kind of matrix operations your PyTensor code performs,
+this might slow some things down (vs. linking with BLAS directly).
 
 2) You can install the default (reference) version of BLAS if the NumPy version
 (against which PyTensor links) does not work. If you have root or sudo access in
@@ -208,10 +228,29 @@ correctly (for example, for MKL this might be ``-lmkl -lguide -lpthread`` or
 ``-lmkl_intel_lp64 -lmkl_intel_thread -lmkl_core -lguide -liomp5 -lmkl_mc
 -lpthread``).
 
+5) Use another backend, such as Numba or JAX, that performs its own BLAS optimizations,
+by setting the configuration mode to ``"NUMBA"`` or ``"JAX"`` and making sure those packages are installed.
+This configuration mode can be set in all the ways that the BLAS flags can be set, described above.
+
+Alternatively, you can pass ``mode='NUMBA'`` when compiling individual PyTensor functions without changing the default,
+or use the ``config.change_flags`` context manager:
+
+.. code-block:: python
+
+    from pytensor import function, config
+    from pytensor.tensor import matrix
+
+    x = matrix('x')
+    y = x @ x.T
+    f = function([x], y, mode='NUMBA')
+
+    with config.change_flags(mode='NUMBA'):
+        # compile a function that benefits from BLAS using NUMBA
+        f = function([x], y)
+
 .. note::
 
-    Make sure your BLAS
-    libraries are available as dynamically-loadable libraries.
+    Make sure your BLAS libraries are available as dynamically-loadable libraries.
     ATLAS is often installed only as a static library. PyTensor is not able to
     use this static library. Your ATLAS installation might need to be modified
     to provide dynamically loadable libraries. (On Linux this
@@ -267,7 +306,7 @@ configuration information. Then, it will print the running time of the same
 benchmarks for your installation. Try to find a CPU similar to yours in
 the table, and check that the single-threaded timings are roughly the same.
 
-PyTensor should link to a parallel version of Blas and use all cores
+PyTensor should link to a parallel version of BLAS and use all cores
 when possible. By default it should use all cores. Set the environment
 variable "OMP_NUM_THREADS=N" to specify to use N threads.
```
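The options above rely on :envvar:`PYTENSOR_FLAGS`, which packs one or more settings into a single environment variable as comma-separated ``key=value`` pairs. As a rough illustration of how such a string decomposes, here is a simplified sketch; it is not PyTensor's actual parser (which also handles commas inside quoted values), and ``parse_flags`` is a hypothetical helper:

```python
def parse_flags(flags_string):
    """Split a PYTENSOR_FLAGS-style string into key/value assignments.

    Simplified sketch: assumes no commas appear inside values.
    """
    assignments = {}
    for chunk in flags_string.split(","):
        chunk = chunk.strip()
        if not chunk:
            continue
        key, _, value = chunk.partition("=")
        # Drop the optional surrounding quotes used for values with spaces.
        assignments[key.strip()] = value.strip().strip("'\"")
    return assignments

flags = parse_flags("blas__ldflags='-llapack -lblas -lcblas',mode=NUMBA")
```

Here ``flags["blas__ldflags"]`` would be ``-llapack -lblas -lcblas`` and ``flags["mode"]`` would be ``NUMBA``.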

pytensor/link/c/cmodule.py (31 additions, 5 deletions)

```diff
@@ -1985,7 +1985,7 @@ def _try_flags(
     )
 
 
-def try_blas_flag(flags):
+def try_blas_flag(flags) -> str:
     test_code = textwrap.dedent(
         """\
         extern "C" double ddot_(int*, double*, int*, double*, int*);
```
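`try_blas_flag` validates candidate link flags by compiling and running a tiny C program that calls the Fortran BLAS symbol `ddot_`. A loose, standard-library-only analogue of that probe is sketched below; the library names are illustrative assumptions, not what PyTensor actually searches for:

```python
import ctypes
import ctypes.util


def find_blas_with_ddot(candidates=("openblas", "blas", "cblas")):
    """Return the path of the first shared library exposing the Fortran
    BLAS symbol ddot_, or None if no candidate can be loaded."""
    for name in candidates:
        path = ctypes.util.find_library(name)
        if path is None:
            continue
        try:
            lib = ctypes.CDLL(path)
        except OSError:
            continue
        if hasattr(lib, "ddot_"):
            return path
    return None


result = find_blas_with_ddot()
```

Unlike the real check, this only verifies that the symbol exists at runtime; compiling and linking a test program, as `try_blas_flag` does, also validates that the flags work with the configured C++ compiler.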
```diff
@@ -2734,12 +2734,30 @@ def check_mkl_openmp():
     )
 
 
-def default_blas_ldflags():
-    """Read local NumPy and MKL build settings and construct `ld` flags from them.
+def default_blas_ldflags() -> str:
+    """Look for an available BLAS implementation in the system.
+
+    This function tries to compile simple C code that uses BLAS
+    if the required files are found in the system.
+    It sequentially tries to link to the following implementations, until one is found:
+    1. Intel MKL with Intel OpenMP threading
+    2. Intel MKL with GNU OpenMP threading
+    3. LAPACK + BLAS
+    4. BLAS alone
+    5. OpenBLAS
 
     Returns
     -------
-    str
+    blas flags: str
+        BLAS flags needed to link to the BLAS implementation found in the system.
+        If no BLAS implementation is found, an empty string is returned.
+
+    Notes
+    -----
+    This function is triggered when `pytensor.config.blas__ldflags` is not given a
+    user default and is first accessed at runtime. It can be rather slow, so it is
+    advised to cache its result in the PYTENSORRC configuration file or in the
+    PyTensor environment flags.
 
     """
 
```
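The docstring above describes a priority-ordered search. Its control flow can be sketched as follows; `linker_check` stands in for PyTensor's actual compile-and-link test (`try_blas_flag`), and the flag strings are illustrative only, not the exact flags PyTensor tries:

```python
# Sketch of the priority-ordered search described in the docstring above.
CANDIDATE_FLAGS = [
    "-lmkl_rt -liomp5",         # 1. Intel MKL with Intel OpenMP threading
    "-lmkl_rt -lgomp",          # 2. Intel MKL with GNU OpenMP threading
    "-llapack -lblas -lcblas",  # 3. LAPACK + BLAS
    "-lblas -lcblas",           # 4. BLAS alone
    "-lopenblas",               # 5. OpenBLAS
]


def pick_blas_flags(linker_check, candidates=CANDIDATE_FLAGS):
    """Return the first candidate flag set that passes the link check,
    or the empty string if none does (callers then fall back to NumPy)."""
    for flags in candidates:
        if linker_check(flags):
            return flags
    return ""


# Example with a stub check that only "accepts" OpenBLAS:
chosen = pick_blas_flags(lambda flags: "openblas" in flags)
```

With that stub, `chosen` is `"-lopenblas"`; with a check that always fails, the function returns `""`, which is exactly the condition that triggers the new warning below.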
```diff
@@ -2788,7 +2806,7 @@ def get_cxx_library_dirs():
 
 
 def check_libs(
     all_libs, required_libs, extra_compile_flags=None, cxx_library_dirs=None
-):
+) -> str:
     if cxx_library_dirs is None:
         cxx_library_dirs = []
     if extra_compile_flags is None:
@@ -2938,6 +2956,14 @@ def check_libs(
     except Exception as e:
         _logger.debug(e)
     _logger.debug("Failed to identify blas ldflags. Will leave them empty.")
+    warnings.warn(
+        "PyTensor could not link to a BLAS installation. Operations that might benefit from BLAS will be severely degraded.\n"
+        "This usually happens when PyTensor is installed via pip. We recommend it be installed via conda/mamba/pixi instead.\n"
+        "Alternatively, you can use an experimental backend such as Numba or JAX, which perform their own BLAS optimizations, "
+        "by setting `pytensor.config.mode = 'NUMBA'` or passing `mode='NUMBA'` when compiling a PyTensor function.\n"
+        "For more options and details see https://pytensor.readthedocs.io/en/latest/troubleshooting.html#how-do-i-configure-test-my-blas-library",
+        UserWarning,
+    )
     return ""
 
 
```
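The commit title says the check is now lazy: the expensive search (and the warning above) only fires when `blas__ldflags` is first read, not at import time. A generic sketch of that lazy-default-with-cache pattern is below; `ConfigSketch` and `_default_blas_ldflags_sketch` are hypothetical names for illustration, not PyTensor's actual config machinery:

```python
import functools


@functools.cache  # the expensive search runs at most once per process
def _default_blas_ldflags_sketch():
    # Stand-in for the real compile-and-link search, which can be slow.
    return ""  # pretend no BLAS implementation was found


class ConfigSketch:
    """Hypothetical config object: the BLAS flags default is only
    computed the first time the attribute is read (lazy evaluation)."""

    def __init__(self):
        self._blas_ldflags = None  # None means "not set by the user"

    @property
    def blas__ldflags(self):
        if self._blas_ldflags is None:
            self._blas_ldflags = _default_blas_ldflags_sketch()
        return self._blas_ldflags

    @blas__ldflags.setter
    def blas__ldflags(self, value):
        self._blas_ldflags = value
```

A user who assigns `blas__ldflags` before it is first read never pays for the search, which is why the troubleshooting docs recommend setting the flag explicitly even when the default would be chosen anyway.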
pytensor/tensor/blas_headers.py (5 additions, 4 deletions)

```diff
@@ -742,6 +742,11 @@ def blas_header_text():
 
     blas_code = ""
     if not config.blas__ldflags:
+        # This code can only be reached by compiling a function with a manually specified GEMM Op.
+        # Normal PyTensor usage will end up with Dot22 or Dot22Scalar instead,
+        # which opt out of C code completely if the BLAS flags are missing.
+        _logger.warning("Using NumPy C-API based implementation for BLAS functions.")
+
         # Include the Numpy version implementation of [sd]gemm_.
         current_filedir = Path(__file__).parent
         blas_common_filepath = current_filedir / "c_code/alt_blas_common.h"
@@ -1003,10 +1008,6 @@ def blas_header_text():
     return header + blas_code
 
 
-if not config.blas__ldflags:
-    _logger.warning("Using NumPy C-API based implementation for BLAS functions.")
-
-
 def mkl_threads_text():
     """C header for MKL threads interface"""
     header = """
```
