@@ -165,13 +165,22 @@ def test_flag_detection():
165
165
166
166
@pytest .fixture (
167
167
scope = "module" ,
168
- params = ["mkl_intel" , "mkl_gnu" , "openblas" , "lapack" , "blas" , "no_blas" ],
168
+ params = [
169
+ "mkl_intel" ,
170
+ "mkl_gnu" ,
171
+ "accelerate" ,
172
+ "openblas" ,
173
+ "lapack" ,
174
+ "blas" ,
175
+ "no_blas" ,
176
+ ],
169
177
)
170
178
def blas_libs (request ):
171
179
key = request .param
172
180
libs = {
173
181
"mkl_intel" : ["mkl_core" , "mkl_rt" , "mkl_intel_thread" , "iomp5" , "pthread" ],
174
182
"mkl_gnu" : ["mkl_core" , "mkl_rt" , "mkl_gnu_thread" , "gomp" , "pthread" ],
183
+ "accelerate" : ["vecLib_placeholder" ],
175
184
"openblas" : ["openblas" , "gfortran" , "gomp" , "m" ],
176
185
"lapack" : ["lapack" , "blas" , "cblas" , "m" ],
177
186
"blas" : ["blas" , "cblas" ],
@@ -190,25 +199,37 @@ def mock_system(request):
190
199
def cxx_search_dirs (blas_libs , mock_system ):
191
200
libext = {"Linux" : "so" , "Windows" : "dll" , "Darwin" : "dylib" }
192
201
libraries = []
202
+ enabled_accelerate_framework = False
193
203
with tempfile .TemporaryDirectory () as d :
194
204
flags = None
195
205
for lib in blas_libs :
196
- lib_path = Path (d ) / f"{ lib } .{ libext [mock_system ]} "
197
- lib_path .write_bytes (b"1" )
198
- libraries .append (lib_path )
199
- if flags is None :
200
- flags = f"-l{ lib } "
206
+ if lib == "vecLib_placeholder" :
207
+ if mock_system != "Darwin" :
208
+ flags = ""
209
+ else :
210
+ flags = "-framework Accelerate"
211
+ enabled_accelerate_framework = True
201
212
else :
202
- flags += f" -l{ lib } "
213
+ lib_path = Path (d ) / f"{ lib } .{ libext [mock_system ]} "
214
+ lib_path .write_bytes (b"1" )
215
+ libraries .append (lib_path )
216
+ if flags is None :
217
+ flags = f"-l{ lib } "
218
+ else :
219
+ flags += f" -l{ lib } "
203
220
if "gomp" in blas_libs and "mkl_gnu_thread" not in blas_libs :
204
221
flags += " -fopenmp"
205
222
if len (blas_libs ) == 0 :
206
223
flags = ""
207
- yield f"libraries: ={ d } " .encode (sys .stdout .encoding ), flags
224
+ yield (
225
+ f"libraries: ={ d } " .encode (sys .stdout .encoding ),
226
+ flags ,
227
+ enabled_accelerate_framework ,
228
+ )
208
229
209
230
210
231
@pytest .fixture (
211
- scope = "function" , params = [False , True ], ids = ["Working_CXX" , "Broken_CXX" ]
232
+ scope = "function" , params = [True , False ], ids = ["Working_CXX" , "Broken_CXX" ]
212
233
)
213
234
def cxx_search_dirs_status (request ):
214
235
return request .param
@@ -219,22 +240,39 @@ def cxx_search_dirs_status(request):
219
240
def test_default_blas_ldflags (
220
241
mock_std_lib_dirs , mock_check_mkl_openmp , cxx_search_dirs , cxx_search_dirs_status
221
242
):
222
- cxx_search_dirs , expected_blas_ldflags = cxx_search_dirs
243
+ cxx_search_dirs , expected_blas_ldflags , enabled_accelerate_framework = (
244
+ cxx_search_dirs
245
+ )
223
246
mock_process = MagicMock ()
224
247
if cxx_search_dirs_status :
225
248
error_message = ""
226
249
mock_process .communicate = lambda * args , ** kwargs : (cxx_search_dirs , b"" )
227
250
mock_process .returncode = 0
228
251
else :
252
+ enabled_accelerate_framework = False
229
253
error_message = "Unsupported argument -print-search-dirs"
230
254
error_message_bytes = error_message .encode (sys .stderr .encoding )
231
255
mock_process .communicate = lambda * args , ** kwargs : (b"" , error_message_bytes )
232
256
mock_process .returncode = 1
257
+
258
+ def patched_compile_tmp (* args , ** kwargs ):
259
+ def wrapped (test_code , tmp_prefix , flags , try_run , output ):
260
+ if len (flags ) >= 2 and flags [:2 ] == ["-framework" , "Accelerate" ]:
261
+ print (enabled_accelerate_framework )
262
+ if enabled_accelerate_framework :
263
+ return (True , True )
264
+ else :
265
+ return (False , False , "" , "Invalid flags -framework Accelerate" )
266
+ else :
267
+ return (True , True )
268
+
269
+ return wrapped
270
+
233
271
with patch ("pytensor.link.c.cmodule.subprocess_Popen" , return_value = mock_process ):
234
272
with patch .object (
235
273
pytensor .link .c .cmodule .GCC_compiler ,
236
274
"try_compile_tmp" ,
237
- return_value = ( True , True ) ,
275
+ new_callable = patched_compile_tmp ,
238
276
):
239
277
if cxx_search_dirs_status :
240
278
assert set (default_blas_ldflags ().split (" " )) == set (
@@ -267,6 +305,9 @@ def windows_conda_libs(blas_libs):
267
305
subdir .mkdir (exist_ok = True , parents = True )
268
306
flags = f'-L"{ subdir } "'
269
307
for lib in blas_libs :
308
+ if lib == "vecLib_placeholder" :
309
+ flags = ""
310
+ break
270
311
lib_path = subdir / f"{ lib } .dll"
271
312
lib_path .write_bytes (b"1" )
272
313
libraries .append (lib_path )
@@ -287,6 +328,16 @@ def test_default_blas_ldflags_conda_windows(
287
328
mock_process = MagicMock ()
288
329
mock_process .communicate = lambda * args , ** kwargs : (b"" , b"" )
289
330
mock_process .returncode = 0
331
+
332
+ def patched_compile_tmp (* args , ** kwargs ):
333
+ def wrapped (test_code , tmp_prefix , flags , try_run , output ):
334
+ if len (flags ) >= 2 and flags [:2 ] == ["-framework" , "Accelerate" ]:
335
+ return (False , False , "" , "Invalid flags -framework Accelerate" )
336
+ else :
337
+ return (True , True )
338
+
339
+ return wrapped
340
+
290
341
with patch ("sys.platform" , "win32" ):
291
342
with patch ("sys.prefix" , mock_sys_prefix ):
292
343
with patch (
@@ -295,7 +346,7 @@ def test_default_blas_ldflags_conda_windows(
295
346
with patch .object (
296
347
pytensor .link .c .cmodule .GCC_compiler ,
297
348
"try_compile_tmp" ,
298
- return_value = ( True , True ) ,
349
+ new_callable = patched_compile_tmp ,
299
350
):
300
351
assert set (default_blas_ldflags ().split (" " )) == set (
301
352
expected_blas_ldflags .split (" " )
0 commit comments