Skip to content

Commit 004281a

Browse files
committed
Some small formatting and style changes
1 parent 95e2e64 commit 004281a

File tree

4 files changed

+13
-5
lines changed

4 files changed

+13
-5
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@
4949
def numba_njit(*args, **kwargs):
5050

5151
kwargs = kwargs.copy()
52-
if "cache" not in kwargs:
53-
kwargs["cache"] = config.numba__cache
52+
kwargs.setdefault("cache", config.numba__cache)
5453

5554
if len(args) > 0 and callable(args[0]):
5655
return numba.njit(*args[1:], **kwargs)(args[0])

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pytensor import config
88
from pytensor.link.numba.dispatch import basic as numba_basic
99
from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
10+
from pytensor.raise_op import CheckAndRaise
1011
from pytensor.tensor.extra_ops import (
1112
Bartlett,
1213
BroadcastTo,
@@ -19,7 +20,6 @@
1920
Unique,
2021
UnravelIndex,
2122
)
22-
from pytensor.raise_op import CheckAndRaise
2323

2424

2525
@numba_funcify.register(Bartlett)
@@ -48,11 +48,13 @@ def numba_funcify_CumOp(op, node, **kwargs):
4848
if mode == "add":
4949

5050
if ndim == 1:
51+
5152
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
5253
def cumop(x):
5354
return np.cumsum(x)
5455

5556
else:
57+
5658
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
5759
def cumop(x):
5860
out_dtype = x.dtype
@@ -70,11 +72,13 @@ def cumop(x):
7072

7173
else:
7274
if ndim == 1:
75+
7376
@numba_basic.numba_njit(fastmath=config.numba__fastmath)
7477
def cumop(x):
7578
return np.cumprod(x)
7679

7780
else:
81+
7882
@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
7983
def cumop(x):
8084
out_dtype = x.dtype

pytensor/link/numba/dispatch/scalar.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,10 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
144144
signature = create_numba_signature(node, force_scalar=True)
145145

146146
return numba_basic.numba_njit(
147-
signature, inline="always", fastmath=config.numba__fastmath, cache=False,
147+
signature,
148+
inline="always",
149+
fastmath=config.numba__fastmath,
150+
cache=False,
148151
)(scalar_op_fn)
149152

150153

pytensor/sparse/sandbox/sp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,9 @@ def evaluate(inshp, kshp, strides=(1, 1), nkern=1, mode="valid", ws=True):
182182
# taking into account multiple
183183
# input features
184184
col = int(
185-
iy * inshp[2] + ix + fmapi * np.prod(inshp[1:], dtype=int)
185+
iy * inshp[2]
186+
+ ix
187+
+ fmapi * np.prod(inshp[1:], dtype=int)
186188
)
187189

188190
# convert oy,ox values to output

0 commit comments

Comments
 (0)