Skip to content

Commit b4e9531

Browse files
committed
Add Accelerate framework blas__ldflags tests
1 parent e73258b commit b4e9531

File tree

3 files changed

+127
-20
lines changed

3 files changed

+127
-20
lines changed

pytensor/link/c/cmodule.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2458,14 +2458,32 @@ def patch_ldflags(flag_list: list[str]) -> list[str]:
24582458
@staticmethod
24592459
def linking_patch(lib_dirs: list[str], libs: list[str]) -> list[str]:
24602460
if sys.platform != "win32":
2461-
return [f"-l{l}" for l in libs]
2461+
patched_libs = []
2462+
framework = False
2463+
for lib in libs:
2464+
# The clang framework flag is handled differently.
2465+
# The flag will have the format -framework framework_name
2466+
# If we find a lib that is called -framework, we keep it and the following
2467+
# entry in the lib list unchanged. Anything else, we add the standard
2468+
# -l library prefix.
2469+
if lib == "-framework":
2470+
framework = True
2471+
patched_libs.append(lib)
2472+
elif framework:
2473+
framework = False
2474+
patched_libs.append(lib)
2475+
else:
2476+
patched_libs.append(f"-l{lib}")
2477+
return patched_libs
24622478
else:
24632479
# In explicit else because of https://github.com/python/mypy/issues/10773
24642480
def sort_key(lib):
24652481
name, *numbers, extension = lib.split(".")
24662482
return (extension == "dll", tuple(map(int, numbers)))
24672483

24682484
patched_lib_ldflags = []
2485+
# Should we also add a framework possibility on windows? I didn't do so because
2486+
# clang is not intended to be used there at the moment.
24692487
for lib in libs:
24702488
ldflag = f"-l{lib}"
24712489
for lib_dir in lib_dirs:
@@ -2873,9 +2891,21 @@ def check_libs(
28732891
)
28742892
except Exception as e:
28752893
_logger.debug(e)
2894+
try:
2895+
# 3. Mac Accelerate framework
2896+
_logger.debug("Checking Accelerate framework")
2897+
flags = ["-framework", "Accelerate"]
2898+
if rpath:
2899+
flags = [*flags, f"-Wl,-rpath,{rpath}"]
2900+
validated_flags = try_blas_flag(flags)
2901+
if validated_flags == "":
2902+
raise Exception("Accelerate framework flag failed ")
2903+
return validated_flags
2904+
except Exception as e:
2905+
_logger.debug(e)
28762906
try:
28772907
_logger.debug("Checking Lapack + blas")
2878-
# 3. Try to use LAPACK + BLAS
2908+
# 4. Try to use LAPACK + BLAS
28792909
return check_libs(
28802910
all_libs,
28812911
required_libs=["lapack", "blas", "cblas", "m"],
@@ -2885,7 +2915,7 @@ def check_libs(
28852915
except Exception as e:
28862916
_logger.debug(e)
28872917
try:
2888-
# 4. Try to use BLAS alone
2918+
# 5. Try to use BLAS alone
28892919
_logger.debug("Checking blas alone")
28902920
return check_libs(
28912921
all_libs,
@@ -2896,7 +2926,7 @@ def check_libs(
28962926
except Exception as e:
28972927
_logger.debug(e)
28982928
try:
2899-
# 5. Try to use openblas
2929+
# 6. Try to use openblas
29002930
_logger.debug("Checking openblas")
29012931
return check_libs(
29022932
all_libs,

pytensor/tensor/blas.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@
7878
import functools
7979
import logging
8080
import os
81+
import shlex
8182
import time
83+
from pathlib import Path
8284

8385
import numpy as np
8486

@@ -396,7 +398,7 @@ def _ldflags(
396398
rval = []
397399
if libs_dir:
398400
found_dyn = False
399-
dirs = [x[2:] for x in ldflags_str.split() if x.startswith("-L")]
401+
dirs = [x[2:] for x in shlex.split(ldflags_str) if x.startswith("-L")]
400402
l = _ldflags(
401403
ldflags_str=ldflags_str,
402404
libs=True,
@@ -409,14 +411,22 @@ def _ldflags(
409411
if f.endswith(".so") or f.endswith(".dylib") or f.endswith(".dll"):
410412
if any(f.find(ll) >= 0 for ll in l):
411413
found_dyn = True
414+
# Special treatment of clang framework. Specifically for MacOS Accelerate
415+
if "-framework" in l and "Accelerate" in l:
416+
found_dyn = True
412417
if not found_dyn and dirs:
413418
_logger.warning(
414419
"We did not find a dynamic library in the "
415420
"library_dir of the library we use for blas. If you use "
416421
"ATLAS, make sure to compile it with dynamics library."
417422
)
418423

419-
for t in ldflags_str.split():
424+
split_flags = shlex.split(ldflags_str)
425+
skip = False
426+
for pos, t in enumerate(split_flags):
427+
if skip:
428+
skip = False
429+
continue
420430
# Remove extra quote.
421431
if (t.startswith("'") and t.endswith("'")) or (
422432
t.startswith('"') and t.endswith('"')
@@ -425,10 +435,26 @@ def _ldflags(
425435

426436
try:
427437
t0, t1 = t[0], t[1]
428-
assert t0 == "-"
438+
assert t0 == "-" or Path(t).exists()
429439
except Exception:
430440
raise ValueError(f'invalid token "{t}" in ldflags_str: "{ldflags_str}"')
431-
if libs_dir and t1 == "L":
441+
if t == "-framework":
442+
skip = True
443+
# Special treatment of clang framework. Specifically for MacOS Accelerate
444+
# The clang framework implicitly adds: header dirs, libraries, and library dirs.
445+
# If we choose to always return these flags, we run into a huge deal amount of
446+
# incompatibilities. For this reason, we only return the framework if libs are
447+
# requested.
448+
if (
449+
libs
450+
and len(split_flags) >= pos
451+
and split_flags[pos + 1] == "Accelerate"
452+
):
453+
# We only add the Accelerate framework, but in the future we could extend it to
454+
# other frameworks
455+
rval.append(t)
456+
rval.append(split_flags[pos + 1])
457+
elif libs_dir and t1 == "L":
432458
rval.append(t[2:])
433459
elif include_dir and t1 == "I":
434460
raise ValueError(

tests/link/c/test_cmodule.py

Lines changed: 63 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,22 @@ def test_flag_detection():
165165

166166
@pytest.fixture(
167167
scope="module",
168-
params=["mkl_intel", "mkl_gnu", "openblas", "lapack", "blas", "no_blas"],
168+
params=[
169+
"mkl_intel",
170+
"mkl_gnu",
171+
"accelerate",
172+
"openblas",
173+
"lapack",
174+
"blas",
175+
"no_blas",
176+
],
169177
)
170178
def blas_libs(request):
171179
key = request.param
172180
libs = {
173181
"mkl_intel": ["mkl_core", "mkl_rt", "mkl_intel_thread", "iomp5", "pthread"],
174182
"mkl_gnu": ["mkl_core", "mkl_rt", "mkl_gnu_thread", "gomp", "pthread"],
183+
"accelerate": ["vecLib_placeholder"],
175184
"openblas": ["openblas", "gfortran", "gomp", "m"],
176185
"lapack": ["lapack", "blas", "cblas", "m"],
177186
"blas": ["blas", "cblas"],
@@ -190,25 +199,37 @@ def mock_system(request):
190199
def cxx_search_dirs(blas_libs, mock_system):
191200
libext = {"Linux": "so", "Windows": "dll", "Darwin": "dylib"}
192201
libraries = []
202+
enabled_accelerate_framework = False
193203
with tempfile.TemporaryDirectory() as d:
194204
flags = None
195205
for lib in blas_libs:
196-
lib_path = Path(d) / f"{lib}.{libext[mock_system]}"
197-
lib_path.write_bytes(b"1")
198-
libraries.append(lib_path)
199-
if flags is None:
200-
flags = f"-l{lib}"
206+
if lib == "vecLib_placeholder":
207+
if mock_system != "Darwin":
208+
flags = ""
209+
else:
210+
flags = "-framework Accelerate"
211+
enabled_accelerate_framework = True
201212
else:
202-
flags += f" -l{lib}"
213+
lib_path = Path(d) / f"{lib}.{libext[mock_system]}"
214+
lib_path.write_bytes(b"1")
215+
libraries.append(lib_path)
216+
if flags is None:
217+
flags = f"-l{lib}"
218+
else:
219+
flags += f" -l{lib}"
203220
if "gomp" in blas_libs and "mkl_gnu_thread" not in blas_libs:
204221
flags += " -fopenmp"
205222
if len(blas_libs) == 0:
206223
flags = ""
207-
yield f"libraries: ={d}".encode(sys.stdout.encoding), flags
224+
yield (
225+
f"libraries: ={d}".encode(sys.stdout.encoding),
226+
flags,
227+
enabled_accelerate_framework,
228+
)
208229

209230

210231
@pytest.fixture(
211-
scope="function", params=[False, True], ids=["Working_CXX", "Broken_CXX"]
232+
scope="function", params=[True, False], ids=["Working_CXX", "Broken_CXX"]
212233
)
213234
def cxx_search_dirs_status(request):
214235
return request.param
@@ -219,22 +240,39 @@ def cxx_search_dirs_status(request):
219240
def test_default_blas_ldflags(
220241
mock_std_lib_dirs, mock_check_mkl_openmp, cxx_search_dirs, cxx_search_dirs_status
221242
):
222-
cxx_search_dirs, expected_blas_ldflags = cxx_search_dirs
243+
cxx_search_dirs, expected_blas_ldflags, enabled_accelerate_framework = (
244+
cxx_search_dirs
245+
)
223246
mock_process = MagicMock()
224247
if cxx_search_dirs_status:
225248
error_message = ""
226249
mock_process.communicate = lambda *args, **kwargs: (cxx_search_dirs, b"")
227250
mock_process.returncode = 0
228251
else:
252+
enabled_accelerate_framework = False
229253
error_message = "Unsupported argument -print-search-dirs"
230254
error_message_bytes = error_message.encode(sys.stderr.encoding)
231255
mock_process.communicate = lambda *args, **kwargs: (b"", error_message_bytes)
232256
mock_process.returncode = 1
257+
258+
def patched_compile_tmp(*args, **kwargs):
259+
def wrapped(test_code, tmp_prefix, flags, try_run, output):
260+
if len(flags) >= 2 and flags[:2] == ["-framework", "Accelerate"]:
261+
print(enabled_accelerate_framework)
262+
if enabled_accelerate_framework:
263+
return (True, True)
264+
else:
265+
return (False, False, "", "Invalid flags -framework Accelerate")
266+
else:
267+
return (True, True)
268+
269+
return wrapped
270+
233271
with patch("pytensor.link.c.cmodule.subprocess_Popen", return_value=mock_process):
234272
with patch.object(
235273
pytensor.link.c.cmodule.GCC_compiler,
236274
"try_compile_tmp",
237-
return_value=(True, True),
275+
new_callable=patched_compile_tmp,
238276
):
239277
if cxx_search_dirs_status:
240278
assert set(default_blas_ldflags().split(" ")) == set(
@@ -267,6 +305,9 @@ def windows_conda_libs(blas_libs):
267305
subdir.mkdir(exist_ok=True, parents=True)
268306
flags = f'-L"{subdir}"'
269307
for lib in blas_libs:
308+
if lib == "vecLib_placeholder":
309+
flags = ""
310+
break
270311
lib_path = subdir / f"{lib}.dll"
271312
lib_path.write_bytes(b"1")
272313
libraries.append(lib_path)
@@ -287,6 +328,16 @@ def test_default_blas_ldflags_conda_windows(
287328
mock_process = MagicMock()
288329
mock_process.communicate = lambda *args, **kwargs: (b"", b"")
289330
mock_process.returncode = 0
331+
332+
def patched_compile_tmp(*args, **kwargs):
333+
def wrapped(test_code, tmp_prefix, flags, try_run, output):
334+
if len(flags) >= 2 and flags[:2] == ["-framework", "Accelerate"]:
335+
return (False, False, "", "Invalid flags -framework Accelerate")
336+
else:
337+
return (True, True)
338+
339+
return wrapped
340+
290341
with patch("sys.platform", "win32"):
291342
with patch("sys.prefix", mock_sys_prefix):
292343
with patch(
@@ -295,7 +346,7 @@ def test_default_blas_ldflags_conda_windows(
295346
with patch.object(
296347
pytensor.link.c.cmodule.GCC_compiler,
297348
"try_compile_tmp",
298-
return_value=(True, True),
349+
new_callable=patched_compile_tmp,
299350
):
300351
assert set(default_blas_ldflags().split(" ")) == set(
301352
expected_blas_ldflags.split(" ")

0 commit comments

Comments
 (0)