Skip to content

Commit e798109

Browse files
committed
Reformat with black
1 parent efa822c commit e798109

File tree

5 files changed

+21
-7
lines changed

5 files changed

+21
-7
lines changed

pytensor/tensor/basic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,9 @@ def get_scalar_constant(v, elemwise=True, only_process_constants=False, max_recu
269269
data = v.data
270270
if data.ndim != 0:
271271
raise NotScalarConstantError()
272-
return get_underlying_scalar_constant(v, elemwise, only_process_constants, max_recur)
272+
return get_underlying_scalar_constant(
273+
v, elemwise, only_process_constants, max_recur
274+
)
273275

274276

275277
def get_underlying_scalar_constant(

pytensor/tensor/rewriting/elemwise.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222
)
2323
from pytensor.graph.rewriting.db import SequenceDB
2424
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
25-
from pytensor.tensor.basic import MakeVector, alloc, cast, get_underlying_scalar_constant
25+
from pytensor.tensor.basic import (
26+
MakeVector,
27+
alloc,
28+
cast,
29+
get_underlying_scalar_constant,
30+
)
2631
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
2732
from pytensor.tensor.exceptions import NotScalarConstantError
2833
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize

pytensor/tensor/rewriting/math.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1493,7 +1493,9 @@ def investigate(node):
14931493
and investigate(node.inputs[0].owner)
14941494
):
14951495
try:
1496-
cst = get_underlying_scalar_constant(node.inputs[1], only_process_constants=True)
1496+
cst = get_underlying_scalar_constant(
1497+
node.inputs[1], only_process_constants=True
1498+
)
14971499

14981500
res = zeros_like(node.inputs[0], dtype=dtype, opt=True)
14991501

@@ -2329,7 +2331,9 @@ def local_abs_merge(fgraph, node):
23292331
inputs.append(i.owner.inputs[0])
23302332
elif isinstance(i, Constant):
23312333
try:
2332-
const = get_underlying_scalar_constant(i, only_process_constants=True)
2334+
const = get_underlying_scalar_constant(
2335+
i, only_process_constants=True
2336+
)
23332337
except NotScalarConstantError:
23342338
return False
23352339
if not (const >= 0).all():

pytensor/tensor/rewriting/subtensor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1729,7 +1729,10 @@ def local_join_subtensors(fgraph, node):
17291729
if step is None:
17301730
continue
17311731
try:
1732-
if get_underlying_scalar_constant(step, only_process_constants=True) != 1:
1732+
if (
1733+
get_underlying_scalar_constant(step, only_process_constants=True)
1734+
!= 1
1735+
):
17331736
return None
17341737
except NotScalarConstantError:
17351738
return None

tests/tensor/test_basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3423,10 +3423,10 @@ def test_None_and_NoneConst(self, only_process_constants):
34233423
None, only_process_constants=only_process_constants
34243424
)
34253425
assert (
3426-
get_underlying_scalar_constant(
3426+
get_underlying_scalar_constant(
34273427
NoneConst, only_process_constants=only_process_constants
34283428
)
3429-
is None
3429+
is None
34303430
)
34313431

34323432

0 commit comments

Comments
 (0)