|
1 |
| -from pytensor.configdefaults import config |
2 |
| -from pytensor.graph.rewriting.basic import in2out |
3 | 1 | from pytensor.link.c.op import COp
|
4 | 2 | from pytensor.link.c.params_type import ParamsType
|
5 | 3 | from pytensor.scalar import bool as bool_t
|
6 |
| -from pytensor.tensor import basic as at |
7 | 4 | from pytensor.tensor.blas import (
|
8 | 5 | Gemv,
|
9 | 6 | Ger,
|
10 | 7 | blas_header_text,
|
11 | 8 | blas_header_version,
|
12 |
| - blas_optdb, |
13 |
| - gemv_inplace, |
14 |
| - gemv_no_inplace, |
15 |
| - ger, |
16 |
| - ger_destructive, |
17 | 9 | ldflags,
|
18 |
| - node_rewriter, |
19 |
| - optdb, |
20 | 10 | )
|
21 | 11 |
|
22 | 12 |
|
@@ -344,23 +334,6 @@ def c_code_cache_version(self):
|
344 | 334 | cger_no_inplace = CGer(False)
|
345 | 335 |
|
346 | 336 |
|
347 |
| -@node_rewriter([ger, ger_destructive]) |
348 |
| -def use_c_ger(fgraph, node): |
349 |
| - if not config.blas__ldflags: |
350 |
| - return |
351 |
| - # Only float32 and float64 are supported for now. |
352 |
| - if node.op == ger and node.outputs[0].dtype in ("float32", "float64"): |
353 |
| - return [CGer(False)(*node.inputs)] |
354 |
| - if node.op == ger_destructive and node.outputs[0].dtype in ("float32", "float64"): |
355 |
| - return [CGer(True)(*node.inputs)] |
356 |
| - |
357 |
| - |
358 |
| -@node_rewriter([CGer(False)]) |
359 |
| -def make_c_ger_destructive(fgraph, node): |
360 |
| - if isinstance(node.op, CGer) and not node.op.destructive: |
361 |
| - return [cger_inplace(*node.inputs)] |
362 |
| - |
363 |
| - |
364 | 337 | # ##### ####### #######
|
365 | 338 | # GEMV
|
366 | 339 | # ##### ####### #######
|
@@ -697,48 +670,3 @@ def check_force_gemv_init():
|
697 | 670 |
|
698 | 671 |
|
699 | 672 | check_force_gemv_init._force_init_beta = None
|
700 |
| - |
701 |
| - |
702 |
| -@node_rewriter([gemv_inplace, gemv_no_inplace]) |
703 |
| -def use_c_gemv(fgraph, node): |
704 |
| - if not config.blas__ldflags: |
705 |
| - return |
706 |
| - # Only float32 and float64 are supported for now. |
707 |
| - if node.op == gemv_no_inplace and node.outputs[0].dtype in ("float32", "float64"): |
708 |
| - return [cgemv_no_inplace(*node.inputs)] |
709 |
| - if node.op == gemv_inplace and node.outputs[0].dtype in ("float32", "float64"): |
710 |
| - return [cgemv_inplace(*node.inputs)] |
711 |
| - |
712 |
| - |
713 |
| -@node_rewriter([CGemv(inplace=False)]) |
714 |
| -def make_c_gemv_destructive(fgraph, node): |
715 |
| - if isinstance(node.op, CGemv) and not node.op.inplace: |
716 |
| - inputs = list(node.inputs) |
717 |
| - dest = inputs[0] |
718 |
| - if ( |
719 |
| - dest.owner |
720 |
| - and isinstance(dest.owner.op, at.AllocEmpty) |
721 |
| - and len(fgraph.clients[dest]) > 1 |
722 |
| - ): |
723 |
| - inputs[0] = at.AllocEmpty(dest.dtype)(*dest.owner.inputs) |
724 |
| - |
725 |
| - return [cgemv_inplace(*inputs)] |
726 |
| - |
727 |
| - |
728 |
| -# ##### ####### ####### |
729 |
| -# Optimizers |
730 |
| -# ##### ####### ####### |
731 |
| - |
732 |
| -blas_optdb.register( |
733 |
| - "use_c_blas", in2out(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20 |
734 |
| -) |
735 |
| - |
736 |
| -# this matches the InplaceBlasOpt defined in blas.py |
737 |
| -optdb.register( |
738 |
| - "c_blas_destructive", |
739 |
| - in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"), |
740 |
| - "fast_run", |
741 |
| - "inplace", |
742 |
| - "c_blas", |
743 |
| - position=70.0, |
744 |
| -) |
0 commit comments