@@ -61,9 +61,8 @@ def register_linker(name, linker):
61
61
# If a string is passed as the optimizer argument in the constructor
62
62
# for Mode, it will be used as the key to retrieve the real optimizer
63
63
# in this dictionary
64
- exclude = []
65
- if not config .cxx :
66
- exclude = ["cxx_only" ]
64
+
65
+ exclude = ["cxx_only" , "BlasOpt" ]
67
66
OPT_NONE = RewriteDatabaseQuery (include = [], exclude = exclude )
68
67
# Even if multiple merge optimizer call will be there, this shouldn't
69
68
# impact performance.
@@ -340,6 +339,11 @@ def __setstate__(self, state):
340
339
optimizer = predefined_optimizers [optimizer ]
341
340
if isinstance (optimizer , RewriteDatabaseQuery ):
342
341
self .provided_optimizer = optimizer
342
+
343
+ # Force numba-required rewrites if using NumbaLinker
344
+ if isinstance (linker , NumbaLinker ):
345
+ optimizer = optimizer .including ("numba" )
346
+
343
347
self ._optimizer = optimizer
344
348
self .call_time = 0
345
349
self .fn_time = 0
@@ -437,16 +441,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
437
441
# string as the key
438
442
# Use VM_linker to allow lazy evaluation by default.
439
443
FAST_COMPILE = Mode (
440
- VMLinker (use_cloop = False , c_thunks = False ),
441
- RewriteDatabaseQuery (include = ["fast_compile" , "py_only" ]),
444
+ NumbaLinker (),
445
+ # TODO: Fast_compile should just use python code, CHANGE ME!
446
+ RewriteDatabaseQuery (
447
+ include = ["fast_compile" , "numba" ],
448
+ exclude = ["cxx_only" , "BlasOpt" , "local_careduce_fusion" ],
449
+ ),
450
+ )
451
+ FAST_RUN = Mode (
452
+ NumbaLinker (),
453
+ RewriteDatabaseQuery (
454
+ include = ["fast_run" , "numba" ],
455
+ exclude = ["cxx_only" , "BlasOpt" , "local_careduce_fusion" ],
456
+ ),
442
457
)
443
- if config .cxx :
444
- FAST_RUN = Mode ("cvm" , "fast_run" )
445
- else :
446
- FAST_RUN = Mode (
447
- "vm" ,
448
- RewriteDatabaseQuery (include = ["fast_run" , "py_only" ]),
449
- )
450
458
451
459
JAX = Mode (
452
460
JAXLinker (),
@@ -512,7 +520,7 @@ def get_mode(orig_string):
512
520
# NanGuardMode use its own linker.
513
521
ret = NanGuardMode (True , True , True , optimizer = config .optimizer )
514
522
else :
515
- # TODO: Can't we look up the name and invoke it rather than using eval here ?
523
+ # TODO: Get rid of this? Or refactor ?
516
524
ret = eval (string + "(linker=config.linker, optimizer=config.optimizer)" )
517
525
elif string in predefined_modes :
518
526
ret = predefined_modes [string ]
@@ -541,6 +549,7 @@ def register_mode(name, mode):
541
549
Add a `Mode` which can be referred to by `name` in `function`.
542
550
543
551
"""
552
+ # TODO: Remove me
544
553
if name in predefined_modes :
545
554
raise ValueError (f"Mode name already taken: { name } " )
546
555
predefined_modes [name ] = mode
0 commit comments