Skip to content

Commit 7a0175a

Browse files
committed
Simplify _ChangeFlagDecorator
1 parent 158a7d0 commit 7a0175a

File tree

4 files changed

+6
-10
lines changed

4 files changed

+6
-10
lines changed

pytensor/configparser.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,7 @@ class ConfigAccessViolation(AttributeError):
3232

3333

3434
class _ChangeFlagsDecorator:
35-
def __init__(self, *args, _root=None, **kwargs):
36-
# the old API supported passing a dict as the first argument:
37-
if args:
38-
assert len(args) == 1 and isinstance(args[0], dict)
39-
kwargs = dict(**args[0], **kwargs)
35+
def __init__(self, _root=None, **kwargs):
4036
self.confs = {k: _root._config_var_dict[k] for k in kwargs}
4137
self.new_vals = kwargs
4238
self._root = _root
@@ -310,14 +306,14 @@ def fetch_val_for_key(self, key, delete_key: bool = False):
310306
except (NoOptionError, NoSectionError):
311307
raise KeyError(key)
312308

313-
def change_flags(self, *args, **kwargs) -> _ChangeFlagsDecorator:
309+
def change_flags(self, **kwargs) -> _ChangeFlagsDecorator:
314310
"""
315311
Use this as a decorator or context manager to change the value of
316312
PyTensor config variables.
317313
318314
Useful during tests.
319315
"""
320-
return _ChangeFlagsDecorator(*args, _root=self, **kwargs)
316+
return _ChangeFlagsDecorator(_root=self, **kwargs)
321317

322318
def warn_unused_flags(self):
323319
for key in self._flags_dict:

tests/link/c/test_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,6 @@ def test_op_with_cenumtype(self):
287287
assert val_billion == val_million * 1000
288288
assert val_two_billions == val_billion * 2
289289

290-
@pytensor.config.change_flags(**{"cmodule__debug": True})
290+
@pytensor.config.change_flags(cmodule__debug=True)
291291
def test_op_with_cenumtype_debug(self):
292292
self.test_op_with_cenumtype()

tests/tensor/test_blas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ def compute_ref(
514514
C = self.get_value(C, transpose_C, slice_C)
515515
return alpha * np.dot(A, B) + beta * C
516516

517-
@config.change_flags({"blas__ldflags": ""})
517+
@config.change_flags(blas__ldflags="")
518518
def run_gemm(
519519
self,
520520
dtype,

tests/test_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def test_config_context():
168168

169169
with root.change_flags(test__config_context="new_value"):
170170
assert root.test__config_context == "new_value"
171-
with root.change_flags({"test__config_context": "new_value2"}):
171+
with root.change_flags(test__config_context="new_value2"):
172172
assert root.test__config_context == "new_value2"
173173
assert root.test__config_context == "new_value"
174174
assert root.test__config_context == "test_default"

0 commit comments

Comments
 (0)