diff --git a/CHANGES b/CHANGES index 82f99d2a53..34d568e1a5 100644 --- a/CHANGES +++ b/CHANGES @@ -1,6 +1,8 @@ Next release ============ +* ENH: Improve JSON interfaces: default settings when reading and consistent output creation + when writing (https://github.com/nipy/nipype/pull/1047) * FIX: AddCSVRow problems when using infields (https://github.com/nipy/nipype/pull/1028) * FIX: Removed unused ANTS registration flag (https://github.com/nipy/nipype/pull/999) * FIX: Amend create_tbss_non_fa() workflow to match FSL's tbss_non_fa command. (https://github.com/nipy/nipype/pull/1033) diff --git a/nipype/interfaces/io.py b/nipype/interfaces/io.py index 4872310b42..002582018c 100644 --- a/nipype/interfaces/io.py +++ b/nipype/interfaces/io.py @@ -1804,8 +1804,9 @@ def _get_ssh_client(self): class JSONFileGrabberInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec): - in_file = File(exists=True, mandatory=True, - desc='JSON source file') + in_file = File(exists=True, desc='JSON source file') + defaults = traits.Dict(desc=('JSON dictionary that sets default output' + 'values, overridden by values found in in_file')) class JSONFileGrabber(IOBase): @@ -1819,12 +1820,15 @@ class JSONFileGrabber(IOBase): >>> from nipype.interfaces.io import JSONFileGrabber >>> jsonSource = JSONFileGrabber() + >>> jsonSource.inputs.defaults = {'param1': u'overrideMe', 'param3': 1.0} + >>> res = jsonSource.run() + >>> res.outputs.get() + {'param3': 1.0, 'param1': u'overrideMe'} >>> jsonSource.inputs.in_file = 'jsongrabber.txt' >>> res = jsonSource.run() - >>> print res.outputs.param1 - exampleStr - >>> print res.outputs.param2 - 4 + >>> res.outputs.get() + {'param3': 1.0, 'param2': 4, 'param1': u'exampleStr'} + """ input_spec = JSONFileGrabberInputSpec @@ -1834,22 +1838,41 @@ class JSONFileGrabber(IOBase): def _list_outputs(self): import json - with open(self.inputs.in_file, 'r') as f: - data = json.load(f) + outputs = {} + if isdefined(self.inputs.in_file): + with open(self.inputs.in_file, 'r') as f: + data = json.load(f) - if not isinstance(data, dict): - raise RuntimeError('JSON input has no dictionary structure') + if not isinstance(data, dict): + raise RuntimeError('JSON input has no dictionary structure') - outputs = {} - for key, value in data.iteritems(): - outputs[key] = value + for key, value in data.iteritems(): + outputs[key] = value + + if isdefined(self.inputs.defaults): + defaults = self.inputs.defaults + for key, value in defaults.iteritems(): + if key not in outputs.keys(): + outputs[key] = value return outputs class JSONFileSinkInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec): out_file = File(desc='JSON sink file') - in_dict = traits.Dict(desc='input JSON dictionary') + in_dict = traits.Dict(value={}, usedefault=True, + desc='input JSON dictionary') + _outputs = traits.Dict(value={}, usedefault=True) + + def __setattr__(self, key, value): + if key not in self.copyable_trait_names(): + if not isdefined(value): + super(JSONFileSinkInputSpec, self).__setattr__(key, value) + self._outputs[key] = value + else: + if key in self._outputs: + self._outputs[key] = value + super(JSONFileSinkInputSpec, self).__setattr__(key, value) class JSONFileSinkOutputSpec(TraitedSpec): @@ -1858,7 +1881,10 @@ class JSONFileSinkOutputSpec(TraitedSpec): class JSONFileSink(IOBase): - """ Very simple frontend for storing values into a JSON file. + """ + Very simple frontend for storing values into a JSON file. + Entries already existing in in_dict will be overridden by matching + entries dynamically added as inputs. .. warning:: @@ -1885,34 +1911,52 @@ class JSONFileSink(IOBase): input_spec = JSONFileSinkInputSpec output_spec = JSONFileSinkOutputSpec - def __init__(self, input_names=[], **inputs): + def __init__(self, infields=[], force_run=True, **inputs): super(JSONFileSink, self).__init__(**inputs) - self._input_names = filename_to_list(input_names) - add_traits(self.inputs, [name for name in self._input_names]) + self._input_names = infields + + undefined_traits = {} + for key in infields: + self.inputs.add_trait(key, traits.Any) + self.inputs._outputs[key] = Undefined + undefined_traits[key] = Undefined + self.inputs.trait_set(trait_change_notify=False, **undefined_traits) + + if force_run: + self._always_run = True + + def _process_name(self, name, val): + if '.' in name: + newkeys = name.split('.') + name = newkeys.pop(0) + nested_dict = {newkeys.pop(): val} + + for nk in reversed(newkeys): + nested_dict = {nk: nested_dict} + val = nested_dict + + return name, val def _list_outputs(self): import json import os.path as op + if not isdefined(self.inputs.out_file): out_file = op.abspath('datasink.json') else: out_file = self.inputs.out_file - out_dict = dict() + out_dict = self.inputs.in_dict - if isdefined(self.inputs.in_dict): - if isinstance(self.inputs.in_dict, dict): - out_dict = self.inputs.in_dict - else: - for name in self._input_names: - val = getattr(self.inputs, name) - val = val if isdefined(val) else 'undefined' - out_dict[name] = val + # Overwrite in_dict entries automatically + for key, val in self.inputs._outputs.items(): + if not isdefined(val) or key == 'trait_added': + continue + key, val = self._process_name(key, val) + out_dict[key] = val with open(out_file, 'w') as f: json.dump(out_dict, f) outputs = self.output_spec().get() outputs['out_file'] = out_file return outputs - - diff --git a/nipype/interfaces/tests/test_auto_JSONFileGrabber.py b/nipype/interfaces/tests/test_auto_JSONFileGrabber.py index 64d6057f5a..f1872e791a 100644 --- a/nipype/interfaces/tests/test_auto_JSONFileGrabber.py +++ b/nipype/interfaces/tests/test_auto_JSONFileGrabber.py @@ -3,11 +3,11 @@ from nipype.interfaces.io import JSONFileGrabber def test_JSONFileGrabber_inputs(): - input_map = dict(ignore_exception=dict(nohash=True, + input_map = dict(defaults=dict(), + ignore_exception=dict(nohash=True, usedefault=True, ), - in_file=dict(mandatory=True, - ), + in_file=dict(), ) inputs = JSONFileGrabber.input_spec() diff --git a/nipype/interfaces/tests/test_auto_JSONFileSink.py b/nipype/interfaces/tests/test_auto_JSONFileSink.py index fb95f03017..7c8cd80f98 100644 --- a/nipype/interfaces/tests/test_auto_JSONFileSink.py +++ b/nipype/interfaces/tests/test_auto_JSONFileSink.py @@ -3,10 +3,13 @@ from nipype.interfaces.io import JSONFileSink def test_JSONFileSink_inputs(): - input_map = dict(ignore_exception=dict(nohash=True, + input_map = dict(_outputs=dict(usedefault=True, + ), + ignore_exception=dict(nohash=True, usedefault=True, ), - in_dict=dict(), + in_dict=dict(usedefault=True, + ), out_file=dict(), ) inputs = JSONFileSink.input_spec() diff --git a/nipype/interfaces/tests/test_io.py b/nipype/interfaces/tests/test_io.py index 07fcb89040..072f304496 100644 --- a/nipype/interfaces/tests/test_io.py +++ b/nipype/interfaces/tests/test_io.py @@ -238,3 +238,41 @@ def test_freesurfersource(): yield assert_equal, fss.inputs.hemi, 'both' yield assert_equal, fss.inputs.subject_id, Undefined yield assert_equal, fss.inputs.subjects_dir, Undefined + + +def test_jsonsink(): + import json + import os + + ds = nio.JSONFileSink() + yield assert_equal, ds.inputs._outputs, {} + ds = nio.JSONFileSink(in_dict={'foo': 'var'}) + yield assert_equal, ds.inputs.in_dict, {'foo': 'var'} + ds = nio.JSONFileSink(infields=['test']) + yield assert_true, 'test' in ds.inputs.copyable_trait_names() + + curdir = os.getcwd() + outdir = mkdtemp() + os.chdir(outdir) + js = nio.JSONFileSink(infields=['test'], in_dict={'foo': 'var'}) + js.inputs.new_entry = 'someValue' + setattr(js.inputs, 'contrasts.alt', 'someNestedValue') + res = js.run() + + with open(res.outputs.out_file, 'r') as f: + data = json.load(f) + yield assert_true, data == {"contrasts": {"alt": "someNestedValue"}, "foo": "var", "new_entry": "someValue"} + + js = nio.JSONFileSink(infields=['test'], in_dict={'foo': 'var'}) + js.inputs.new_entry = 'someValue' + js.inputs.test = 'testInfields' + setattr(js.inputs, 'contrasts.alt', 'someNestedValue') + res = js.run() + + with open(res.outputs.out_file, 'r') as f: + data = json.load(f) + yield assert_true, data == {"test": "testInfields", "contrasts": {"alt": "someNestedValue"}, "foo": "var", "new_entry": "someValue"} + + os.chdir(curdir) + shutil.rmtree(outdir) +