Skip to content

Commit 48c80b4

Browse files
committed
add overwrite argument to json load
1 parent 678b710 commit 48c80b4

File tree

2 files changed

+29
-8
lines changed

2 files changed

+29
-8
lines changed

nipype/interfaces/base.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656

5757
iflogger = logging.getLogger('interface')
5858

59+
PY35 = sys.version_info >= (3, 5)
60+
5961
if runtime_profile:
6062
try:
6163
import psutil
@@ -762,12 +764,13 @@ def __init__(self, from_file=None, **inputs):
762764
if not self.input_spec:
763765
raise Exception('No input_spec in class: %s' %
764766
self.__class__.__name__)
767+
765768
self.inputs = self.input_spec(**inputs)
766769
self.estimated_memory_gb = 1
767770
self.num_threads = 1
768771

769772
if from_file is not None:
770-
self.load_inputs_from_json(from_file)
773+
self.load_inputs_from_json(from_file, overwrite=False)
771774

772775

773776
@classmethod
@@ -1152,24 +1155,33 @@ def version(self):
11521155
self.__class__.__name__)
11531156
return self._version
11541157

1155-
def load_inputs_from_json(self, json_file):
1158+
def load_inputs_from_json(self, json_file, overwrite=True):
11561159
"""
11571160
A convenient way to load pre-set inputs from a JSON file.
11581161
"""
11591162

11601163
with open(json_file) as fhandle:
11611164
inputs_dict = json.load(fhandle)
11621165

1163-
for key, val in list(inputs_dict.items()):
1166+
for key, newval in list(inputs_dict.items()):
11641167
if not hasattr(self.inputs, key):
1165-
setattr(self.inputs, key, val)
1168+
continue
1169+
val = getattr(self.inputs, key)
1170+
if overwrite or not isdefined(val):
1171+
setattr(self.inputs, key, newval)
11661172

11671173
def save_inputs_to_json(self, json_file):
11681174
"""
11691175
A convenient way to save current inputs to a JSON file.
11701176
"""
1177+
inputs = self.inputs.get()
1178+
for key, val in list(inputs.items()):
1179+
if not isdefined(val):
1180+
inputs.pop(key, None)
1181+
1182+
iflogger.debug('saving inputs {}', inputs)
11711183
with open(json_file, 'w') as fhandle:
1172-
json.dump(self.inputs.get(), fhandle, indent=4)
1184+
json.dump(inputs, fhandle, indent=4)
11731185

11741186

11751187
class Stream(object):

nipype/interfaces/tests/test_base.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,11 @@ def test_BaseInterface_load_save_inputs():
459459
tmp_dir = tempfile.mkdtemp()
460460
tmp_json = os.path.join(tmp_dir, 'settings.json')
461461

462+
def _rem_undefined(indict):
463+
for key, val in list(indict.items()):
464+
if not nib.isdefined(val):
465+
indict.pop(key, None)
466+
return indict
462467

463468
class InputSpec(nib.TraitedSpec):
464469
input1 = nib.traits.Int()
@@ -472,16 +477,20 @@ class DerivedInterface(nib.BaseInterface):
472477
def __init__(self, **inputs):
473478
super(DerivedInterface, self).__init__(**inputs)
474479

475-
inputs_dict = {'input1': 12, 'input2': 3.4, 'input3': True,
480+
inputs_dict = {'input1': 12, 'input3': True,
476481
'input4': 'some string'}
477482
bif = DerivedInterface(**inputs_dict)
478483
bif.save_inputs_to_json(tmp_json)
479484
bif2 = DerivedInterface()
480485
bif2.load_inputs_from_json(tmp_json)
481-
yield assert_equal, inputs_dict, bif2.inputs.get()
486+
yield assert_equal, _rem_undefined(bif2.inputs.get()), inputs_dict
482487

483488
bif3 = DerivedInterface(from_file=tmp_json)
484-
yield assert_equal, inputs_dict, bif3.inputs.get()
489+
yield assert_equal, _rem_undefined(bif3.inputs.get()), inputs_dict
490+
491+
inputs_dict.update({'input4': 'some other string'})
492+
bif4 = DerivedInterface(from_file=tmp_json, input4='some other string')
493+
yield assert_equal, _rem_undefined(bif4.inputs.get()), inputs_dict
485494

486495
def assert_not_raises(fn, *args, **kwargs):
487496
fn(*args, **kwargs)

0 commit comments

Comments
 (0)