Skip to content

Commit 1788edf

Browse files
committed
fix json dump, add tests to check overwrite
1 parent 701dbc5 commit 1788edf

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

nipype/interfaces/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
iflogger = logging.getLogger('interface')
4949

5050
PY35 = sys.version_info >= (3, 5)
51+
PY3 = sys.version_info[0] > 2
5152

5253
if runtime_profile:
5354
try:
@@ -1197,8 +1198,8 @@ def save_inputs_to_json(self, json_file):
11971198
"""
11981199
inputs = self.inputs.get_traitsfree()
11991200
iflogger.debug('saving inputs {}', inputs)
1200-
with open(json_file, 'w') as fhandle:
1201-
json.dump(inputs, fhandle, indent=4)
1201+
with open(json_file, 'w' if PY3 else 'wb') as fhandle:
1202+
json.dump(inputs, fhandle, indent=4, ensure_ascii=False)
12021203

12031204

12041205
class Stream(object):

nipype/interfaces/tests/test_base.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -461,12 +461,6 @@ def test_BaseInterface_load_save_inputs():
461461
tmp_dir = tempfile.mkdtemp()
462462
tmp_json = os.path.join(tmp_dir, 'settings.json')
463463

464-
def _rem_undefined(indict):
465-
for key, val in list(indict.items()):
466-
if not nib.isdefined(val):
467-
indict.pop(key, None)
468-
return indict
469-
470464
class InputSpec(nib.TraitedSpec):
471465
input1 = nib.traits.Int()
472466
input2 = nib.traits.Float()
@@ -485,14 +479,23 @@ def __init__(self, **inputs):
485479
bif.save_inputs_to_json(tmp_json)
486480
bif2 = DerivedInterface()
487481
bif2.load_inputs_from_json(tmp_json)
488-
yield assert_equal, _rem_undefined(bif2.inputs.get()), inputs_dict
482+
yield assert_equal, bif2.inputs.get_traitsfree(), inputs_dict
489483

490484
bif3 = DerivedInterface(from_file=tmp_json)
491-
yield assert_equal, _rem_undefined(bif3.inputs.get()), inputs_dict
485+
yield assert_equal, bif3.inputs.get_traitsfree(), inputs_dict
486+
487+
inputs_dict2 = inputs_dict.copy()
488+
inputs_dict2.update({'input4': 'some other string'})
489+
bif4 = DerivedInterface(from_file=tmp_json, input4=inputs_dict2['input4'])
490+
yield assert_equal, bif4.inputs.get_traitsfree(), inputs_dict2
491+
492+
bif5 = DerivedInterface(input4=inputs_dict2['input4'])
493+
bif5.load_inputs_from_json(tmp_json, overwrite=False)
494+
yield assert_equal, bif5.inputs.get_traitsfree(), inputs_dict2
492495

493-
inputs_dict.update({'input4': 'some other string'})
494-
bif4 = DerivedInterface(from_file=tmp_json, input4='some other string')
495-
yield assert_equal, _rem_undefined(bif4.inputs.get()), inputs_dict
496+
bif6 = DerivedInterface(input4=inputs_dict2['input4'])
497+
bif6.load_inputs_from_json(tmp_json)
498+
yield assert_equal, bif6.inputs.get_traitsfree(), inputs_dict
496499

497500
def assert_not_raises(fn, *args, **kwargs):
498501
fn(*args, **kwargs)

0 commit comments

Comments
 (0)