Skip to content

Commit 2b1367a

Browse files
committed
split blas Ops and rewrites
Having Ops and rewrites in the same files was causing circular imports.
1 parent 138daee commit 2b1367a

File tree

8 files changed

+1025
-671
lines changed

8 files changed

+1025
-671
lines changed

pytensor/tensor/blas.py

Lines changed: 6 additions & 554 deletions
Large diffs are not rendered by default.

pytensor/tensor/blas_c.py

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,12 @@
1-
from pytensor.configdefaults import config
2-
from pytensor.graph.rewriting.basic import in2out
31
from pytensor.link.c.op import COp
42
from pytensor.link.c.params_type import ParamsType
53
from pytensor.scalar import bool as bool_t
6-
from pytensor.tensor import basic as at
74
from pytensor.tensor.blas import (
85
Gemv,
96
Ger,
107
blas_header_text,
118
blas_header_version,
12-
blas_optdb,
13-
gemv_inplace,
14-
gemv_no_inplace,
15-
ger,
16-
ger_destructive,
179
ldflags,
18-
node_rewriter,
19-
optdb,
2010
)
2111

2212

@@ -344,23 +334,6 @@ def c_code_cache_version(self):
344334
cger_no_inplace = CGer(False)
345335

346336

347-
@node_rewriter([ger, ger_destructive])
348-
def use_c_ger(fgraph, node):
349-
if not config.blas__ldflags:
350-
return
351-
# Only float32 and float64 are supported for now.
352-
if node.op == ger and node.outputs[0].dtype in ("float32", "float64"):
353-
return [CGer(False)(*node.inputs)]
354-
if node.op == ger_destructive and node.outputs[0].dtype in ("float32", "float64"):
355-
return [CGer(True)(*node.inputs)]
356-
357-
358-
@node_rewriter([CGer(False)])
359-
def make_c_ger_destructive(fgraph, node):
360-
if isinstance(node.op, CGer) and not node.op.destructive:
361-
return [cger_inplace(*node.inputs)]
362-
363-
364337
# ##### ####### #######
365338
# GEMV
366339
# ##### ####### #######
@@ -697,48 +670,3 @@ def check_force_gemv_init():
697670

698671

699672
check_force_gemv_init._force_init_beta = None
700-
701-
702-
@node_rewriter([gemv_inplace, gemv_no_inplace])
703-
def use_c_gemv(fgraph, node):
704-
if not config.blas__ldflags:
705-
return
706-
# Only float32 and float64 are supported for now.
707-
if node.op == gemv_no_inplace and node.outputs[0].dtype in ("float32", "float64"):
708-
return [cgemv_no_inplace(*node.inputs)]
709-
if node.op == gemv_inplace and node.outputs[0].dtype in ("float32", "float64"):
710-
return [cgemv_inplace(*node.inputs)]
711-
712-
713-
@node_rewriter([CGemv(inplace=False)])
714-
def make_c_gemv_destructive(fgraph, node):
715-
if isinstance(node.op, CGemv) and not node.op.inplace:
716-
inputs = list(node.inputs)
717-
dest = inputs[0]
718-
if (
719-
dest.owner
720-
and isinstance(dest.owner.op, at.AllocEmpty)
721-
and len(fgraph.clients[dest]) > 1
722-
):
723-
inputs[0] = at.AllocEmpty(dest.dtype)(*dest.owner.inputs)
724-
725-
return [cgemv_inplace(*inputs)]
726-
727-
728-
# ##### ####### #######
729-
# Optimizers
730-
# ##### ####### #######
731-
732-
blas_optdb.register(
733-
"use_c_blas", in2out(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20
734-
)
735-
736-
# this matches the InplaceBlasOpt defined in blas.py
737-
optdb.register(
738-
"c_blas_destructive",
739-
in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"),
740-
"fast_run",
741-
"inplace",
742-
"c_blas",
743-
position=70.0,
744-
)

pytensor/tensor/blas_scipy.py

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,7 @@
44

55
import numpy as np
66

7-
from pytensor.graph.rewriting.basic import in2out
8-
from pytensor.tensor.blas import (
9-
Ger,
10-
blas_optdb,
11-
ger,
12-
ger_destructive,
13-
have_fblas,
14-
node_rewriter,
15-
optdb,
16-
)
7+
from pytensor.tensor.blas import Ger, have_fblas
178

189

1910
if have_fblas:
@@ -56,36 +47,3 @@ def perform(self, node, inputs, output_storage):
5647

5748
scipy_ger_no_inplace = ScipyGer(False)
5849
scipy_ger_inplace = ScipyGer(True)
59-
60-
61-
@node_rewriter([ger, ger_destructive])
62-
def use_scipy_ger(fgraph, node):
63-
if node.op == ger:
64-
return [scipy_ger_no_inplace(*node.inputs)]
65-
66-
67-
@node_rewriter([scipy_ger_no_inplace])
68-
def make_ger_destructive(fgraph, node):
69-
if node.op == scipy_ger_no_inplace:
70-
return [scipy_ger_inplace(*node.inputs)]
71-
72-
73-
use_scipy_blas = in2out(use_scipy_ger)
74-
make_scipy_blas_destructive = in2out(make_ger_destructive)
75-
76-
if have_fblas:
77-
# scipy_blas is scheduled in the blas_optdb very late, because scipy sortof
78-
# sucks, but it is almost always present.
79-
# C implementations should be scheduled earlier than this, so that they take
80-
# precedence. Once the original Ger is replaced, then these optimizations
81-
# have no effect.
82-
blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100)
83-
84-
# this matches the InplaceBlasOpt defined in blas.py
85-
optdb.register(
86-
"make_scipy_blas_destructive",
87-
make_scipy_blas_destructive,
88-
"fast_run",
89-
"inplace",
90-
position=70.0,
91-
)

pytensor/tensor/rewriting/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import pytensor.tensor.rewriting.basic
2+
import pytensor.tensor.rewriting.blas
3+
import pytensor.tensor.rewriting.blas_c
4+
import pytensor.tensor.rewriting.blas_scipy
25
import pytensor.tensor.rewriting.elemwise
36
import pytensor.tensor.rewriting.extra_ops
47

0 commit comments

Comments
 (0)