diff --git a/CHANGES b/CHANGES index 061d161a06..01a09b735a 100644 --- a/CHANGES +++ b/CHANGES @@ -1,6 +1,7 @@ Upcoming release (0.14.1) ================ +* MAINT: Cleaning / simplify ``Node`` (https://github.com/nipy/nipype/pull/#2325) 0.14.0 (November 29, 2017) ========================== diff --git a/nipype/interfaces/freesurfer/tests/test_FSSurfaceCommand.py b/nipype/interfaces/freesurfer/tests/test_FSSurfaceCommand.py index acaa5d466d..70701e5f57 100644 --- a/nipype/interfaces/freesurfer/tests/test_FSSurfaceCommand.py +++ b/nipype/interfaces/freesurfer/tests/test_FSSurfaceCommand.py @@ -29,9 +29,11 @@ def test_FSSurfaceCommand_inputs(): @pytest.mark.skipif(fs.no_freesurfer(), reason="freesurfer is not installed") -def test_associated_file(): +def test_associated_file(tmpdir): fssrc = FreeSurferSource(subjects_dir=fs.Info.subjectsdir(), subject_id='fsaverage', hemi='lh') + fssrc.base_dir = tmpdir.strpath + fssrc.resource_monitor = False fsavginfo = fssrc.run().outputs.get() diff --git a/nipype/interfaces/spm/base.py b/nipype/interfaces/spm/base.py index 391528e83b..7882fa1280 100644 --- a/nipype/interfaces/spm/base.py +++ b/nipype/interfaces/spm/base.py @@ -29,8 +29,11 @@ # Local imports from ... import logging from ...utils import spm_docs as sd, NUMPY_MMAP -from ..base import (BaseInterface, traits, isdefined, InputMultiPath, - BaseInterfaceInputSpec, Directory, Undefined, ImageFile) +from ..base import ( + BaseInterface, traits, isdefined, InputMultiPath, + BaseInterfaceInputSpec, Directory, Undefined, + ImageFile, PackageInfo +) from ..matlab import MatlabCommand from ...external.due import due, Doi, BibTeX @@ -123,12 +126,37 @@ def scans_for_fnames(fnames, keep4d=False, separate_sessions=False): return flist -class Info(object): +class Info(PackageInfo): """Handles SPM version information """ - @staticmethod - def version(matlab_cmd=None, paths=None, use_mcr=None): - """Returns the path to the SPM directory in the Matlab path + _path = None + _name = None + + @classmethod + def path(klass, matlab_cmd=None, paths=None, use_mcr=None): + if klass._path: + return klass._path + klass.getinfo(matlab_cmd, paths, use_mcr) + return klass._path + + @classmethod + def version(klass, matlab_cmd=None, paths=None, use_mcr=None): + if klass._version: + return klass._version + klass.getinfo(matlab_cmd, paths, use_mcr) + return klass._version + + @classmethod + def name(klass, matlab_cmd=None, paths=None, use_mcr=None): + if klass._name: + return klass._name + klass.getinfo(matlab_cmd, paths, use_mcr) + return klass._name + + @classmethod + def getinfo(klass, matlab_cmd=None, paths=None, use_mcr=None): + """ + Returns the path to the SPM directory in the Matlab path If path not found, returns None. 
Parameters @@ -152,10 +180,17 @@ def version(matlab_cmd=None, paths=None, use_mcr=None): returns None of path not found """ + if klass._name and klass._path and klass._version: + return { + 'name': klass._name, + 'path': klass._path, + 'release': klass._version + } + use_mcr = use_mcr or 'FORCE_SPMMCR' in os.environ - matlab_cmd = ((use_mcr and os.getenv('SPMMCRCMD')) or - os.getenv('MATLABCMD') or - 'matlab -nodesktop -nosplash') + matlab_cmd = ( + (use_mcr and os.getenv('SPMMCRCMD')) or + os.getenv('MATLABCMD', 'matlab -nodesktop -nosplash')) mlab = MatlabCommand(matlab_cmd=matlab_cmd, resource_monitor=False) @@ -184,13 +219,17 @@ def version(matlab_cmd=None, paths=None, use_mcr=None): # No Matlab -- no spm logger.debug('%s', e) return None - else: - out = sd._strip_header(out.runtime.stdout) - out_dict = {} - for part in out.split('|'): - key, val = part.split(':') - out_dict[key] = val - return out_dict + + out = sd._strip_header(out.runtime.stdout) + out_dict = {} + for part in out.split('|'): + key, val = part.split(':') + out_dict[key] = val + + klass._version = out_dict['release'] + klass._path = out_dict['path'] + klass._name = out_dict['name'] + return out_dict def no_spm(): @@ -288,13 +327,15 @@ def _matlab_cmd_update(self): @property def version(self): - version_dict = Info.version(matlab_cmd=self.inputs.matlab_cmd, - paths=self.inputs.paths, - use_mcr=self.inputs.use_mcr) - if version_dict: - return '.'.join((version_dict['name'].split('SPM')[-1], - version_dict['release'])) - return version_dict + info_dict = Info.getinfo( + matlab_cmd=self.inputs.matlab_cmd, + paths=self.inputs.paths, + use_mcr=self.inputs.use_mcr + ) + if info_dict: + return '%s.%s' % ( + info_dict['name'].split('SPM')[-1], + info_dict['release']) @property def jobtype(self): diff --git a/nipype/interfaces/spm/tests/test_base.py b/nipype/interfaces/spm/tests/test_base.py index d1c517a0d3..57d0d88c21 100644 --- a/nipype/interfaces/spm/tests/test_base.py +++ b/nipype/interfaces/spm/tests/test_base.py @@ -16,12 +16,8 @@ from nipype.interfaces.spm.base import SPMCommandInputSpec from nipype.interfaces.base import traits -try: - matlab_cmd = os.environ['MATLABCMD'] -except: - matlab_cmd = 'matlab' - -mlab.MatlabCommand.set_default_matlab_cmd(matlab_cmd) +mlab.MatlabCommand.set_default_matlab_cmd( + os.getenv('MATLABCMD', 'matlab')) def test_scan_for_fnames(create_files_in_directory): @@ -35,10 +31,10 @@ def test_scan_for_fnames(create_files_in_directory): if not save_time: @pytest.mark.skipif(no_spm(), reason="spm is not installed") def test_spm_path(): - spm_path = spm.Info.version()['path'] + spm_path = spm.Info.path() if spm_path is not None: assert isinstance(spm_path, (str, bytes)) - assert 'spm' in spm_path + assert 'spm' in spm_path.lower() def test_use_mfile(): diff --git a/nipype/interfaces/spm/tests/test_model.py b/nipype/interfaces/spm/tests/test_model.py index e9e8a48849..307c4f1786 100644 --- a/nipype/interfaces/spm/tests/test_model.py +++ b/nipype/interfaces/spm/tests/test_model.py @@ -6,12 +6,8 @@ import nipype.interfaces.spm.model as spm import nipype.interfaces.matlab as mlab -try: - matlab_cmd = os.environ['MATLABCMD'] -except: - matlab_cmd = 'matlab' - -mlab.MatlabCommand.set_default_matlab_cmd(matlab_cmd) +mlab.MatlabCommand.set_default_matlab_cmd( + os.getenv('MATLABCMD', 'matlab')) def test_level1design(): diff --git a/nipype/interfaces/spm/tests/test_preprocess.py b/nipype/interfaces/spm/tests/test_preprocess.py index 4bf86285ad..f167ad521a 100644 --- 
a/nipype/interfaces/spm/tests/test_preprocess.py +++ b/nipype/interfaces/spm/tests/test_preprocess.py @@ -10,12 +10,8 @@ from nipype.interfaces.spm import no_spm import nipype.interfaces.matlab as mlab -try: - matlab_cmd = os.environ['MATLABCMD'] -except: - matlab_cmd = 'matlab' - -mlab.MatlabCommand.set_default_matlab_cmd(matlab_cmd) +mlab.MatlabCommand.set_default_matlab_cmd( + os.getenv('MATLABCMD', 'matlab')) def test_slicetiming(): @@ -88,7 +84,7 @@ def test_normalize12_list_outputs(create_files_in_directory): @pytest.mark.skipif(no_spm(), reason="spm is not installed") def test_segment(): - if spm.Info.version()['name'] == "SPM12": + if spm.Info.name() == "SPM12": assert spm.Segment()._jobtype == 'tools' assert spm.Segment()._jobname == 'oldseg' else: @@ -98,7 +94,7 @@ def test_segment(): @pytest.mark.skipif(no_spm(), reason="spm is not installed") def test_newsegment(): - if spm.Info.version()['name'] == "SPM12": + if spm.Info.name() == "SPM12": assert spm.NewSegment()._jobtype == 'spatial' assert spm.NewSegment()._jobname == 'preproc' else: diff --git a/nipype/pipeline/engine/nodes.py b/nipype/pipeline/engine/nodes.py index 5b972a7692..ed1fde9d28 100644 --- a/nipype/pipeline/engine/nodes.py +++ b/nipype/pipeline/engine/nodes.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # -*- coding: utf-8 -*- # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: @@ -15,49 +14,53 @@ """ from __future__ import print_function, division, unicode_literals, absolute_import -from builtins import range, object, str, bytes, open +from builtins import range, str, bytes, open -from future import standard_library -standard_library.install_aliases() from collections import OrderedDict -from copy import deepcopy -import pickle -from glob import glob -import gzip import os import os.path as op import shutil -import errno import socket -from shutil import rmtree -import sys +from copy import deepcopy +from glob import glob + from tempfile import mkdtemp -from hashlib import sha1 +from future import standard_library from ... 
import config, logging -from ...utils.misc import (flatten, unflatten, str2bool) -from ...utils.filemanip import (save_json, FileNotFoundError, - filename_to_list, list_to_filename, - copyfiles, fnames_presuffix, loadpkl, - split_filename, load_json, savepkl, - write_rst_header, write_rst_dict, - write_rst_list, to_str, md5) -from ...interfaces.base import (traits, InputMultiPath, CommandLine, - Undefined, TraitedSpec, DynamicTraitedSpec, - Bunch, InterfaceResult, Interface, - TraitDictObject, TraitListObject, isdefined) -from .utils import (generate_expanded_graph, modify_paths, - export_graph, make_output_dir, write_workflow_prov, - clean_working_directory, format_dot, topological_sort, - get_print_name, merge_dict, evaluate_connect_function) +from ...utils.misc import flatten, unflatten, str2bool, dict_diff +from ...utils.filemanip import ( + md5, FileNotFoundError, filename_to_list, list_to_filename, + copyfiles, fnames_presuffix, loadpkl, split_filename, load_json, makedirs, + emptydirs, savepkl, to_str +) + +from ...interfaces.base import ( + traits, InputMultiPath, CommandLine, Undefined, DynamicTraitedSpec, + Bunch, InterfaceResult, Interface, isdefined +) +from .utils import ( + _parameterization_dir, + save_hashfile as _save_hashfile, + load_resultfile as _load_resultfile, + save_resultfile as _save_resultfile, + nodelist_runner as _node_runner, + strip_temp as _strip_temp, + write_report, + clean_working_directory, + merge_dict, evaluate_connect_function +) from .base import EngineBase +standard_library.install_aliases() + logger = logging.getLogger('workflow') class Node(EngineBase): - """Wraps interface objects for use in pipeline + """ + Wraps interface objects for use in pipeline A Node creates a sandbox-like directory for executing the underlying interface. 
It will copy or link inputs into this directory to ensure that @@ -149,38 +152,40 @@ def __init__(self, interface, name, iterables=None, itersource=None, multiprocessing pool """ - base_dir = None - if 'base_dir' in kwargs: - base_dir = kwargs['base_dir'] - super(Node, self).__init__(name, base_dir) - # Make sure an interface is set, and that it is an Interface if interface is None: raise IOError('Interface must be provided') if not isinstance(interface, Interface): raise IOError('interface must be an instance of an Interface') - self._interface = interface - self.name = name - self._result = None + super(Node, self).__init__(name, kwargs.get('base_dir')) + + self.name = name + self._interface = interface + self._hierarchy = None + self._got_inputs = False + self._originputs = None + self._output_dir = None self.iterables = iterables self.synchronize = synchronize self.itersource = itersource self.overwrite = overwrite self.parameterization = None - self.run_without_submitting = run_without_submitting self.input_source = {} - self.needed_outputs = [] self.plugin_args = {} + self.run_without_submitting = run_without_submitting self._mem_gb = mem_gb self._n_procs = n_procs + + # Downstream n_procs if hasattr(self._interface.inputs, 'num_threads') and self._n_procs is not None: self._interface.inputs.num_threads = self._n_procs + # Initialize needed_outputs + self.needed_outputs = [] if needed_outputs: self.needed_outputs = sorted(needed_outputs) - self._got_inputs = False @property def interface(self): @@ -189,12 +194,8 @@ def interface(self): @property def result(self): - if self._result: - return self._result - else: - cwd = self.output_dir() - result, _, _ = self._load_resultfile(cwd) - return result + """Get result from result file (do not hold it in memory)""" + return _load_resultfile(self.output_dir(), self.name)[0] @property def inputs(self): @@ -221,11 +222,11 @@ def n_procs(self): """Get the estimated number of processes/threads""" if self._n_procs is not None: return self._n_procs - elif hasattr(self._interface.inputs, 'num_threads') and isdefined( - self._interface.inputs.num_threads): + if hasattr(self._interface.inputs, + 'num_threads') and isdefined(self._interface.inputs.num_threads): return self._interface.inputs.num_threads - else: - return 1 + + return 1 @n_procs.setter def n_procs(self, value): @@ -238,6 +239,11 @@ def n_procs(self, value): def output_dir(self): """Return the location of the output directory for the node""" + # Output dir is cached + if self._output_dir: + return self._output_dir + + # Calculate & cache otherwise if self.base_dir is None: self.base_dir = mkdtemp() outputdir = self.base_dir @@ -246,54 +252,96 @@ def output_dir(self): if self.parameterization: params_str = ['{}'.format(p) for p in self.parameterization] if not str2bool(self.config['execution']['parameterize_dirs']): - params_str = [self._parameterization_dir(p) for p in params_str] + params_str = [_parameterization_dir(p) for p in params_str] outputdir = op.join(outputdir, *params_str) - return op.abspath(op.join(outputdir, self.name)) + + self._output_dir = op.abspath(op.join(outputdir, self.name)) + return self._output_dir def set_input(self, parameter, val): - """ Set interface input value""" - logger.debug('setting nodelevel(%s) input %s = %s', + """Set interface input value""" + logger.debug('[Node] %s - setting input %s = %s', self.name, parameter, to_str(val)) setattr(self.inputs, parameter, deepcopy(val)) def get_output(self, parameter): """Retrieve a particular output of the 
node""" - val = None - if self._result: - val = getattr(self._result.outputs, parameter) - else: - cwd = self.output_dir() - result, _, _ = self._load_resultfile(cwd) - if result and result.outputs: - val = getattr(result.outputs, parameter) - return val + return getattr(self.result.outputs, parameter, None) def help(self): - """ Print interface help""" + """Print interface help""" self._interface.help() def hash_exists(self, updatehash=False): + """ + Check if the interface has been run previously, and whether + cached results are viable for reuse + """ + # Get a dictionary with hashed filenames and a hashvalue # of the dictionary itself. hashed_inputs, hashvalue = self._get_hashval() outdir = self.output_dir() - if op.exists(outdir): - logger.debug('Output dir: %s', to_str(os.listdir(outdir))) - hashfiles = glob(op.join(outdir, '_0x*.json')) - logger.debug('Found hashfiles: %s', to_str(hashfiles)) - if len(hashfiles) > 1: - logger.info(hashfiles) - logger.info('Removing multiple hashfiles and forcing node to rerun') - for hashfile in hashfiles: - os.unlink(hashfile) hashfile = op.join(outdir, '_0x%s.json' % hashvalue) - logger.debug('Final hashfile: %s', hashfile) - if updatehash and op.exists(outdir): - logger.debug("Updating hash: %s", hashvalue) - for file in glob(op.join(outdir, '_0x*.json')): - os.remove(file) - self._save_hashfile(hashfile, hashed_inputs) - return op.exists(hashfile), hashvalue, hashfile, hashed_inputs + hash_exists = op.exists(hashfile) + + logger.debug('[Node] hash value=%s, exists=%s', hashvalue, hash_exists) + + if op.exists(outdir): + # Find previous hashfiles + globhashes = glob(op.join(outdir, '_0x*.json')) + unfinished = [path for path in globhashes if path.endswith('_unfinished.json')] + hashfiles = list(set(globhashes) - set(unfinished)) + if len(hashfiles) > 1: + for rmfile in hashfiles: + os.remove(rmfile) + + raise RuntimeError( + '[Node] Cache ERROR - Found %d previous hashfiles indicating ' + 'that the ``base_dir`` for this node went stale. Please re-run the ' + 'workflow.' % len(hashfiles)) + + # This should not happen, but clean up and break if so. + if unfinished and updatehash: + for rmfile in unfinished: + os.remove(rmfile) + + raise RuntimeError( + '[Node] Cache ERROR - Found unfinished hashfiles (%d) indicating ' + 'that the ``base_dir`` for this node went stale. Please re-run the ' + 'workflow.' 
% len(unfinished)) + + # Remove outdated hashfile + if hashfiles and hashfiles[0] != hashfile: + logger.info('[Node] Outdated hashfile found for "%s", removing and forcing node ' + 'to rerun.', self.fullname) + + # If logging is more verbose than INFO (20), print diff between hashes + loglevel = logger.getEffectiveLevel() + if loglevel < 20: # Lazy logging: only < INFO + split_out = split_filename(hashfiles[0]) + exp_hash_file_base = split_out[1] + exp_hash = exp_hash_file_base[len('_0x'):] + logger.log(loglevel, "[Node] Old/new hashes = %s/%s", exp_hash, hashvalue) + try: + prev_inputs = load_json(hashfiles[0]) + except Exception: + pass + else: + logger.log(loglevel, dict_diff(prev_inputs, hashed_inputs, 10)) + + os.remove(hashfiles[0]) + + # Update only possible if it exists + if hash_exists and updatehash: + logger.debug("[Node] Updating hash: %s", hashvalue) + _save_hashfile(hashfile, hashed_inputs) + + logger.debug( + 'updatehash=%s, overwrite=%s, always_run=%s, hash_exists=%s, ' + 'hash_method=%s', updatehash, self.overwrite, self._interface.always_run, + hash_exists, self.config['execution']['hash_method'].lower()) + return hash_exists, hashvalue, hashfile, hashed_inputs def run(self, updatehash=False): """Execute the node in its directory. @@ -302,145 +350,115 @@ def run(self, updatehash=False): ---------- updatehash: boolean - Update the hash stored in the output directory + When the hash stored in the output directory as a result of a previous run + does not match that calculated for this execution, updatehash=True only + updates the hash without re-running. """ - # check to see if output directory and hash exist if self.config is None: - self.config = deepcopy(config._sections) - else: - self.config = merge_dict(deepcopy(config._sections), self.config) - if not self._got_inputs: - self._get_inputs() - self._got_inputs = True + self.config = {} + self.config = merge_dict(deepcopy(config._sections), self.config) + self._get_inputs() + + # Check if output directory exists outdir = self.output_dir() - logger.info("Executing node %s in dir: %s", self.fullname, outdir) if op.exists(outdir): - logger.debug('Output dir: %s', to_str(os.listdir(outdir))) + logger.debug('Output directory (%s) exists and is %sempty,', + outdir, 'not ' * bool(os.listdir(outdir))) + + # Check hash, check whether run should be enforced + logger.info('[Node] Setting-up "%s" in "%s".', self.fullname, outdir) hash_info = self.hash_exists(updatehash=updatehash) hash_exists, hashvalue, hashfile, hashed_inputs = hash_info - logger.debug( - 'updatehash=%s, overwrite=%s, always_run=%s, hash_exists=%s', - updatehash, self.overwrite, self._interface.always_run, hash_exists) - if (not updatehash and (((self.overwrite is None and - self._interface.always_run) or - self.overwrite) or not - hash_exists)): - logger.debug("Node hash: %s", hashvalue) - - # by rerunning we mean only nodes that did finish to run previously - json_pat = op.join(outdir, '_0x*.json') - json_unfinished_pat = op.join(outdir, '_0x*_unfinished.json') - need_rerun = (op.exists(outdir) and not - isinstance(self, MapNode) and - len(glob(json_pat)) != 0 and - len(glob(json_unfinished_pat)) == 0) - if need_rerun: - logger.debug( - "Rerunning node:\n" - "updatehash = %s, self.overwrite = %s, self._interface.always_run = %s, " - "os.path.exists(%s) = %s, hash_method = %s", updatehash, self.overwrite, - self._interface.always_run, hashfile, op.exists(hashfile), - self.config['execution']['hash_method'].lower()) - log_debug = config.get('logging', 
'workflow_level') == 'DEBUG' - if log_debug and not op.exists(hashfile): - exp_hash_paths = glob(json_pat) - if len(exp_hash_paths) == 1: - split_out = split_filename(exp_hash_paths[0]) - exp_hash_file_base = split_out[1] - exp_hash = exp_hash_file_base[len('_0x'):] - logger.debug("Previous node hash = %s", exp_hash) - try: - prev_inputs = load_json(exp_hash_paths[0]) - except: - pass - else: - logging.logdebug_dict_differences(prev_inputs, - hashed_inputs) - cannot_rerun = (str2bool( - self.config['execution']['stop_on_first_rerun']) and not - (self.overwrite is None and self._interface.always_run)) - if cannot_rerun: - raise Exception(("Cannot rerun when 'stop_on_first_rerun' " - "is set to True")) - hashfile_unfinished = op.join(outdir, - '_0x%s_unfinished.json' % - hashvalue) - if op.exists(hashfile): - os.remove(hashfile) - rm_outdir = (op.exists(outdir) and not - (op.exists(hashfile_unfinished) and - self._interface.can_resume) and not - isinstance(self, MapNode)) - if rm_outdir: - logger.debug("Removing old %s and its contents", outdir) - try: - rmtree(outdir) - except OSError as ex: - outdircont = os.listdir(outdir) - if ((ex.errno == errno.ENOTEMPTY) and (len(outdircont) == 0)): - logger.warn( - 'An exception was raised trying to remove old %s, but the path ' - 'seems empty. Is it an NFS mount?. Passing the exception.', outdir) - elif ((ex.errno == errno.ENOTEMPTY) and (len(outdircont) != 0)): - logger.debug( - 'Folder contents (%d items): %s', len(outdircont), outdircont) - raise ex - else: - raise ex + force_run = self.overwrite or (self.overwrite is None and self._interface.always_run) + + # If the node is cached, check on pklz files and finish + if hash_exists and (updatehash or not force_run): + logger.debug("Only updating node hashes or skipping execution") + inputs_file = op.join(outdir, '_inputs.pklz') + if not op.exists(inputs_file): + logger.debug('Creating inputs file %s', inputs_file) + savepkl(inputs_file, self.inputs.get_traitsfree()) + + node_file = op.join(outdir, '_node.pklz') + if not op.exists(node_file): + logger.debug('Creating node file %s', node_file) + savepkl(node_file, self) + + result = self._run_interface(execute=False, updatehash=updatehash) + logger.info('[Node] "%s" found cached%s.', self.fullname, + ' (and hash updated)' * updatehash) + return result - else: - logger.debug( - "%s found and can_resume is True or Node is a MapNode - resuming execution", - hashfile_unfinished) - if isinstance(self, MapNode): - # remove old json files - for filename in glob(op.join(outdir, '_0x*.json')): - os.unlink(filename) - outdir = make_output_dir(outdir) - self._save_hashfile(hashfile_unfinished, hashed_inputs) - self.write_report(report_type='preexec', cwd=outdir) - savepkl(op.join(outdir, '_node.pklz'), self) - savepkl(op.join(outdir, '_inputs.pklz'), - self.inputs.get_traitsfree()) - try: - self._run_interface() - except: - os.remove(hashfile_unfinished) - raise - shutil.move(hashfile_unfinished, hashfile) - self.write_report(report_type='postexec', cwd=outdir) - else: - if not op.exists(op.join(outdir, '_inputs.pklz')): - logger.debug('%s: creating inputs file', self.name) - savepkl(op.join(outdir, '_inputs.pklz'), - self.inputs.get_traitsfree()) - if not op.exists(op.join(outdir, '_node.pklz')): - logger.debug('%s: creating node file', self.name) - savepkl(op.join(outdir, '_node.pklz'), self) - logger.debug("Hashfile exists. 
Skipping execution") - self._run_interface(execute=False, updatehash=updatehash) - logger.debug('Finished running %s in dir: %s\n', self._id, outdir) - return self._result - - # Private functions - def _parameterization_dir(self, param): - """ - Returns the directory name for the given parameterization string as follows: - - If the parameterization is longer than 32 characters, then - return the SHA-1 hex digest. - - Otherwise, return the parameterization unchanged. - """ - if len(param) > 32: - return sha1(param.encode()).hexdigest() + # by rerunning we mean only nodes that did finish to run previously + if hash_exists and not isinstance(self, MapNode): + logger.debug('[Node] Rerunning "%s"', self.fullname) + if not force_run and str2bool(self.config['execution']['stop_on_first_rerun']): + raise Exception('Cannot rerun when "stop_on_first_rerun" is set to True') + + # Remove hashfile if it exists at this point (re-running) + if op.exists(hashfile): + os.remove(hashfile) + + # Hashfile while running + hashfile_unfinished = op.join( + outdir, '_0x%s_unfinished.json' % hashvalue) + + # Delete directory contents if this is not a MapNode or can't resume + rm_outdir = not isinstance(self, MapNode) and not ( + self._interface.can_resume and op.isfile(hashfile_unfinished)) + if rm_outdir: + emptydirs(outdir, noexist_ok=True) else: - return param + logger.debug('[%sNode] Resume - hashfile=%s', + 'Map' * int(isinstance(self, MapNode)), + hashfile_unfinished) + if isinstance(self, MapNode): + # remove old json files + for filename in glob(op.join(outdir, '_0x*.json')): + os.remove(filename) + + # Make sure outdir is created + makedirs(outdir, exist_ok=True) + + # Store runtime-hashfile, pre-execution report, the node and the inputs set. + _save_hashfile(hashfile_unfinished, hashed_inputs) + write_report(self, report_type='preexec', + is_mapnode=isinstance(self, MapNode)) + savepkl(op.join(outdir, '_node.pklz'), self) + savepkl(op.join(outdir, '_inputs.pklz'), + self.inputs.get_traitsfree()) + + try: + cwd = os.getcwd() + except OSError: + # Changing back to cwd is probably not necessary + # but this makes sure there's somewhere to change to. 
+ cwd = op.split(outdir)[0] + logger.warning('Current folder "%s" does not exist, changing to "%s" instead.', + os.getenv('PWD', 'unknown'), cwd) + + os.chdir(outdir) + try: + result = self._run_interface(execute=True) + except Exception: + logger.warning('[Node] Error on "%s" (%s)', self.fullname, outdir) + # Tear-up after error + os.remove(hashfile_unfinished) + raise + finally: # Ensure we come back to the original CWD + os.chdir(cwd) + + # Tear-up after success + shutil.move(hashfile_unfinished, hashfile) + write_report(self, report_type='postexec', + is_mapnode=isinstance(self, MapNode)) + logger.info('[Node] Finished "%s".', self.fullname) + return result def _get_hashval(self): """Return a hash of the input state""" - if not self._got_inputs: - self._get_inputs() - self._got_inputs = True + self._get_inputs() hashed_inputs, hashvalue = self.inputs.get_hashval( hash_method=self.config['execution']['hash_method']) rm_extra = self.config['execution']['remove_unnecessary_outputs'] @@ -453,30 +471,15 @@ def _get_hashval(self): hashed_inputs.append(('needed_outputs', sorted_outputs)) return hashed_inputs, hashvalue - def _save_hashfile(self, hashfile, hashed_inputs): - try: - save_json(hashfile, hashed_inputs) - except (IOError, TypeError): - err_type = sys.exc_info()[0] - if err_type is TypeError: - # XXX - SG current workaround is to just - # create the hashed file and not put anything - # in it - with open(hashfile, 'wt') as fd: - fd.writelines(str(hashed_inputs)) - - logger.debug( - 'Unable to write a particular type to the json file') - else: - logger.critical('Unable to open the file in write mode: %s', - hashfile) - def _get_inputs(self): """Retrieve inputs from pointers to results file This mechanism can be easily extended/replaced to retrieve data from other data sources (e.g., XNAT, HTTP, etc.,.) 
""" + if self._got_inputs: + return + logger.debug('Setting node inputs') for key, info in list(self.input_source.items()): logger.debug('input: %s', key) @@ -488,9 +491,8 @@ def _get_inputs(self): output_name = info[1][0] value = getattr(results.outputs, output_name) if isdefined(value): - output_value = evaluate_connect_function(info[1][1], - info[1][2], - value) + output_value = evaluate_connect_function( + info[1][1], info[1][2], value) else: output_name = info[1] try: @@ -509,88 +511,17 @@ def _get_inputs(self): e.args = (e.args[0] + "\n" + '\n'.join(msg),) raise + # Successfully set inputs + self._got_inputs = True + def _run_interface(self, execute=True, updatehash=False): if updatehash: - return - old_cwd = os.getcwd() - os.chdir(self.output_dir()) - self._result = self._run_command(execute) - os.chdir(old_cwd) - - def _save_results(self, result, cwd): - resultsfile = op.join(cwd, 'result_%s.pklz' % self.name) - if result.outputs: - try: - outputs = result.outputs.get() - except TypeError: - outputs = result.outputs.dictcopy() # outputs was a bunch - result.outputs.set(**modify_paths(outputs, relative=True, - basedir=cwd)) - - savepkl(resultsfile, result) - logger.debug('saved results in %s', resultsfile) - - if result.outputs: - result.outputs.set(**outputs) + return self._load_results() + return self._run_command(execute) - def _load_resultfile(self, cwd): - """Load results if it exists in cwd - - Parameter - --------- - - cwd : working directory of node - - Returns - ------- - - result : InterfaceResult structure - aggregate : boolean indicating whether node should aggregate_outputs - attribute error : boolean indicating whether there was some mismatch in - versions of traits used to store result and hence node needs to - rerun - """ - aggregate = True - resultsoutputfile = op.join(cwd, 'result_%s.pklz' % self.name) - result = None - attribute_error = False - if op.exists(resultsoutputfile): - pkl_file = gzip.open(resultsoutputfile, 'rb') - try: - result = pickle.load(pkl_file) - except UnicodeDecodeError: - # Was this pickle created with Python 2.x? - pickle.load(pkl_file, fix_imports=True, encoding='utf-8') - logger.warn('Successfully loaded pickle in compatibility mode') - except (traits.TraitError, AttributeError, ImportError, - EOFError) as err: - if isinstance(err, (AttributeError, ImportError)): - attribute_error = True - logger.debug('attribute error: %s probably using ' - 'different trait pickled file', str(err)) - else: - logger.debug( - 'some file does not exist. 
hence trait cannot be set') - else: - if result.outputs: - try: - outputs = result.outputs.get() - except TypeError: - outputs = result.outputs.dictcopy() # outputs == Bunch - try: - result.outputs.set(**modify_paths(outputs, - relative=False, - basedir=cwd)) - except FileNotFoundError: - logger.debug('conversion to full path results in ' - 'non existent file') - aggregate = False - pkl_file.close() - logger.debug('Aggregate: %s', aggregate) - return result, aggregate, attribute_error - - def _load_results(self, cwd): - result, aggregate, attribute_error = self._load_resultfile(cwd) + def _load_results(self): + cwd = self.output_dir() + result, aggregate, attribute_error = _load_resultfile(cwd, self.name) # try aggregating first if aggregate: logger.debug('aggregating results') @@ -598,7 +529,7 @@ def _load_results(self, cwd): old_inputs = loadpkl(op.join(cwd, '_inputs.pklz')) self.inputs.trait_set(**old_inputs) if not isinstance(self, MapNode): - self._copyfiles_to_wd(cwd, True, linksonly=True) + self._copyfiles_to_wd(linksonly=True) aggouts = self._interface.aggregate_outputs( needed_outputs=self.needed_outputs) runtime = Bunch(cwd=cwd, @@ -610,177 +541,131 @@ def _load_results(self, cwd): runtime=runtime, inputs=self._interface.inputs.get_traitsfree(), outputs=aggouts) - self._save_results(result, cwd) + _save_resultfile(result, cwd, self.name) else: logger.debug('aggregating mapnode results') - self._run_interface() - result = self._result + result = self._run_interface() return result def _run_command(self, execute, copyfiles=True): - cwd = os.getcwd() - if execute and copyfiles: + + if not execute: + try: + result = self._load_results() + except (FileNotFoundError, AttributeError): + # if aggregation does not work, rerun the node + logger.info("[Node] Some of the outputs were not found: " + "rerunning node.") + copyfiles = False # OE: this was like this before, + execute = True # I'll keep them for safety + else: + logger.info( + "[Node] Cached - collecting precomputed outputs") + return result + + # Run command: either execute is true or load_results failed. 
+ runtime = Bunch(returncode=1, + environ=dict(os.environ), + hostname=socket.gethostname()) + result = InterfaceResult( + interface=self._interface.__class__, + runtime=runtime, + inputs=self._interface.inputs.get_traitsfree()) + + outdir = self.output_dir() + if copyfiles: self._originputs = deepcopy(self._interface.inputs) - if execute: - runtime = Bunch(returncode=1, - environ=dict(os.environ), - hostname=socket.gethostname()) - result = InterfaceResult( - interface=self._interface.__class__, - runtime=runtime, - inputs=self._interface.inputs.get_traitsfree()) - self._result = result - if copyfiles: - self._copyfiles_to_wd(cwd, execute) - - message = 'Running node "%s" ("%s.%s")' - if issubclass(self._interface.__class__, CommandLine): - try: - cmd = self._interface.cmdline - except Exception as msg: - self._result.runtime.stderr = msg - raise - cmdfile = op.join(cwd, 'command.txt') - with open(cmdfile, 'wt') as fd: - print(cmd + "\n", file=fd) - message += ', a CommandLine Interface with command:\n%s' % cmd - logger.info(message + '.', self.name, self._interface.__module__, - self._interface.__class__.__name__) + self._copyfiles_to_wd(execute=execute) + + message = '[Node] Running "%s" ("%s.%s")' + if issubclass(self._interface.__class__, CommandLine): try: - result = self._interface.run() + cmd = self._interface.cmdline except Exception as msg: - self._save_results(result, cwd) - self._result.runtime.stderr = msg + result.runtime.stderr = '%s\n\n%s' % ( + getattr(result.runtime, 'stderr', ''), msg) + _save_resultfile(result, outdir, self.name) raise + cmdfile = op.join(outdir, 'command.txt') + with open(cmdfile, 'wt') as fd: + print(cmd + "\n", file=fd) + message += ', a CommandLine Interface with command:\n%s' % cmd + logger.info(message, self.name, self._interface.__module__, + self._interface.__class__.__name__) + try: + result = self._interface.run() + except Exception as msg: + result.runtime.stderr = '%s\n\n%s' % ( + getattr(result.runtime, 'stderr', ''), msg) + _save_resultfile(result, outdir, self.name) + raise + + dirs2keep = None + if isinstance(self, MapNode): + dirs2keep = [op.join(outdir, 'mapflow')] + + result.outputs = clean_working_directory( + result.outputs, outdir, + self._interface.inputs, + self.needed_outputs, + self.config, + dirs2keep=dirs2keep + ) + _save_resultfile(result, outdir, self.name) - dirs2keep = None - if isinstance(self, MapNode): - dirs2keep = [op.join(cwd, 'mapflow')] - result.outputs = clean_working_directory(result.outputs, cwd, - self._interface.inputs, - self.needed_outputs, - self.config, - dirs2keep=dirs2keep) - self._save_results(result, cwd) - else: - logger.info("Collecting precomputed outputs") - try: - result = self._load_results(cwd) - except (FileNotFoundError, AttributeError): - # if aggregation does not work, rerun the node - logger.info(("Some of the outputs were not found: " - "rerunning node.")) - result = self._run_command(execute=True, copyfiles=False) return result - def _strip_temp(self, files, wd): - out = [] - for f in files: - if isinstance(f, list): - out.append(self._strip_temp(f, wd)) - else: - out.append(f.replace(op.join(wd, '_tempinput'), wd)) - return out - - def _copyfiles_to_wd(self, outdir, execute, linksonly=False): - """ copy files over and change the inputs""" - if hasattr(self._interface, '_get_filecopy_info'): - logger.debug('copying files to wd [execute=%s, linksonly=%s]', - str(execute), str(linksonly)) - if execute and linksonly: - olddir = outdir - outdir = op.join(outdir, '_tempinput') - 
os.makedirs(outdir) - for info in self._interface._get_filecopy_info(): - files = self.inputs.get().get(info['key']) - if not isdefined(files): - continue - if files: - infiles = filename_to_list(files) - if execute: - if linksonly: - if not info['copy']: - newfiles = copyfiles(infiles, - [outdir], - copy=info['copy'], - create_new=True) - else: - newfiles = fnames_presuffix(infiles, - newpath=outdir) - newfiles = self._strip_temp( - newfiles, - op.abspath(olddir).split(op.sep)[-1]) - else: - newfiles = copyfiles(infiles, - [outdir], - copy=info['copy'], - create_new=True) + def _copyfiles_to_wd(self, execute=True, linksonly=False): + """copy files over and change the inputs""" + if not hasattr(self._interface, '_get_filecopy_info'): + # Nothing to be done + return + + logger.debug('copying files to wd [execute=%s, linksonly=%s]', + execute, linksonly) + + outdir = self.output_dir() + if execute and linksonly: + olddir = outdir + outdir = op.join(outdir, '_tempinput') + makedirs(outdir, exist_ok=True) + + for info in self._interface._get_filecopy_info(): + files = self.inputs.get().get(info['key']) + if not isdefined(files) or not files: + continue + + infiles = filename_to_list(files) + if execute: + if linksonly: + if not info['copy']: + newfiles = copyfiles(infiles, + [outdir], + copy=info['copy'], + create_new=True) else: - newfiles = fnames_presuffix(infiles, newpath=outdir) - if not isinstance(files, list): - newfiles = list_to_filename(newfiles) - setattr(self.inputs, info['key'], newfiles) - if execute and linksonly: - rmtree(outdir) + newfiles = fnames_presuffix(infiles, + newpath=outdir) + newfiles = _strip_temp( + newfiles, + op.abspath(olddir).split(op.sep)[-1]) + else: + newfiles = copyfiles(infiles, + [outdir], + copy=info['copy'], + create_new=True) + else: + newfiles = fnames_presuffix(infiles, newpath=outdir) + if not isinstance(files, list): + newfiles = list_to_filename(newfiles) + setattr(self.inputs, info['key'], newfiles) + if execute and linksonly: + emptydirs(outdir, noexist_ok=True) def update(self, **opts): + """Update inputs""" self.inputs.update(**opts) - def write_report(self, report_type=None, cwd=None): - if not str2bool(self.config['execution']['create_report']): - return - report_dir = op.join(cwd, '_report') - report_file = op.join(report_dir, 'report.rst') - if not op.exists(report_dir): - os.makedirs(report_dir) - if report_type == 'preexec': - logger.debug('writing pre-exec report to %s', report_file) - fp = open(report_file, 'wt') - fp.writelines(write_rst_header('Node: %s' % get_print_name(self), - level=0)) - fp.writelines(write_rst_list(['Hierarchy : %s' % self.fullname, - 'Exec ID : %s' % self._id])) - fp.writelines(write_rst_header('Original Inputs', level=1)) - fp.writelines(write_rst_dict(self.inputs.get())) - if report_type == 'postexec': - logger.debug('writing post-exec report to %s', report_file) - fp = open(report_file, 'at') - fp.writelines(write_rst_header('Execution Inputs', level=1)) - fp.writelines(write_rst_dict(self.inputs.get())) - exit_now = (not hasattr(self.result, 'outputs') or - self.result.outputs is None) - if exit_now: - return - fp.writelines(write_rst_header('Execution Outputs', level=1)) - if isinstance(self.result.outputs, Bunch): - fp.writelines(write_rst_dict(self.result.outputs.dictcopy())) - elif self.result.outputs: - fp.writelines(write_rst_dict(self.result.outputs.get())) - if isinstance(self, MapNode): - fp.close() - return - fp.writelines(write_rst_header('Runtime info', level=1)) - # Init rst dictionary of 
runtime stats - rst_dict = {'hostname': self.result.runtime.hostname, - 'duration': self.result.runtime.duration} - # Try and insert memory/threads usage if available - if config.resource_monitor: - rst_dict['mem_peak_gb'] = self.result.runtime.mem_peak_gb - rst_dict['cpu_percent'] = self.result.runtime.cpu_percent - - if hasattr(self.result.runtime, 'cmdline'): - rst_dict['command'] = self.result.runtime.cmdline - fp.writelines(write_rst_dict(rst_dict)) - else: - fp.writelines(write_rst_dict(rst_dict)) - if hasattr(self.result.runtime, 'merged'): - fp.writelines(write_rst_header('Terminal output', level=2)) - fp.writelines(write_rst_list(self.result.runtime.merged)) - if hasattr(self.result.runtime, 'environ'): - fp.writelines(write_rst_header('Environment', level=2)) - fp.writelines(write_rst_dict(self.result.runtime.environ)) - fp.close() - class JoinNode(Node): """Wraps interface objects that join inputs into a list. @@ -832,7 +717,8 @@ def __init__(self, interface, name, joinsource, joinfield=None, """ super(JoinNode, self).__init__(interface, name, **kwargs) - self.joinsource = joinsource + self._joinsource = None # The member should be defined + self.joinsource = joinsource # Let the setter do the job """the join predecessor iterable node""" if not joinfield: @@ -907,7 +793,7 @@ def _add_join_item_field(self, field, index): Return the new field name """ # the new field name - name = self._join_item_field_name(field, index) + name = "%sJ%d" % (field, index + 1) # make a copy of the join trait trait = self._inputs.trait(field, False, True) # add the join item trait to the override traits @@ -915,10 +801,6 @@ def _add_join_item_field(self, field, index): return name - def _join_item_field_name(self, field, index): - """Return the field suffixed by the index + 1""" - return "%sJ%d" % (field, index + 1) - def _override_join_traits(self, basetraits, fields): """Convert the given join fields to accept an input that is a list item rather than a list. Non-join fields @@ -967,7 +849,8 @@ def _collate_join_field_inputs(self): try: setattr(self._interface.inputs, field, val) except Exception as e: - raise ValueError(">>JN %s %s %s %s %s: %s" % (self, field, val, self.inputs.copyable_trait_names(), self.joinfield, e)) + raise ValueError(">>JN %s %s %s %s %s: %s" % ( + self, field, val, self.inputs.copyable_trait_names(), self.joinfield, e)) elif hasattr(self._interface.inputs, field): # copy the non-join field val = getattr(self._inputs, field) @@ -993,13 +876,14 @@ def _collate_input_value(self, field): basetrait = self._interface.inputs.trait(field) if isinstance(basetrait.trait_type, traits.Set): return set(val) - elif self._unique: + + if self._unique: return list(OrderedDict.fromkeys(val)) - else: - return val + + return val def _slot_value(self, field, index): - slot_field = self._join_item_field_name(field, index) + slot_field = "%sJ%d" % (field, index + 1) try: return getattr(self._inputs, slot_field) except AttributeError as e: @@ -1039,10 +923,13 @@ def __init__(self, interface, iterfield, name, serial=False, nested=False, **kwa name : alphanumeric string node specific name serial : boolean - flag to enforce executing the jobs of the mapnode in a serial manner rather than parallel - nested : boolea - support for nested lists, if set the input list will be flattened before running, and the - nested list structure of the outputs will be resored + flag to enforce executing the jobs of the mapnode in a serial + manner rather than parallel + nested : boolean + support for nested lists. 
If set, the input list will be flattened + before running and the nested list structure of the outputs will + be resored. + See Node docstring for additional keyword arguments. """ @@ -1080,15 +967,15 @@ def _create_dynamic_traits(self, basetraits, fields=None, nitems=None): return output def set_input(self, parameter, val): - """ Set interface input value or nodewrapper attribute - + """ + Set interface input value or nodewrapper attribute Priority goes to interface. """ logger.debug('setting nodelevel(%s) input %s = %s', to_str(self), parameter, to_str(val)) - self._set_mapnode_input(self.inputs, parameter, deepcopy(val)) + self._set_mapnode_input(parameter, deepcopy(val)) - def _set_mapnode_input(self, object, name, newvalue): + def _set_mapnode_input(self, name, newvalue): logger.debug('setting mapnode(%s) input: %s -> %s', to_str(self), name, to_str(newvalue)) if name in self.iterfield: @@ -1097,10 +984,8 @@ def _set_mapnode_input(self, object, name, newvalue): setattr(self._interface.inputs, name, newvalue) def _get_hashval(self): - """ Compute hash including iterfield lists.""" - if not self._got_inputs: - self._get_inputs() - self._got_inputs = True + """Compute hash including iterfield lists.""" + self._get_inputs() self._check_iterfield() hashinputs = deepcopy(self._interface.inputs) for name in self.iterfield: @@ -1135,8 +1020,6 @@ def inputs(self): def outputs(self): if self._interface._outputs(): return Bunch(self._interface._outputs().get()) - else: - return None def _make_nodes(self, cwd=None): if cwd is None: @@ -1146,7 +1029,7 @@ def _make_nodes(self, cwd=None): else: nitems = len(filename_to_list(getattr(self.inputs, self.iterfield[0]))) for i in range(nitems): - nodename = '_' + self.name + str(i) + nodename = '_%s%d' % (self.name, i) node = Node(deepcopy(self._interface), n_procs=self._n_procs, mem_gb=self._mem_gb, @@ -1156,8 +1039,9 @@ def _make_nodes(self, cwd=None): base_dir=op.join(cwd, 'mapflow'), name=nodename) node.plugin_args = self.plugin_args - node._interface.inputs.trait_set( + node.interface.inputs.trait_set( **deepcopy(self._interface.inputs.get())) + node.interface.resource_monitor = self._interface.resource_monitor for field in self.iterfield: if self.nested: fieldvals = flatten(filename_to_list(getattr(self.inputs, field))) @@ -1168,35 +1052,23 @@ def _make_nodes(self, cwd=None): node.config = self.config yield i, node - def _node_runner(self, nodes, updatehash=False): - old_cwd = os.getcwd() - for i, node in nodes: - err = None - try: - node.run(updatehash=updatehash) - except Exception as this_err: - err = this_err - if str2bool(self.config['execution']['stop_on_first_crash']): - raise - finally: - os.chdir(old_cwd) - yield i, node, err - def _collate_results(self, nodes): - self._result = InterfaceResult(interface=[], runtime=[], - provenance=[], inputs=[], - outputs=self.outputs) + finalresult = InterfaceResult( + interface=[], runtime=[], provenance=[], inputs=[], + outputs=self.outputs) returncode = [] - for i, node, err in nodes: - self._result.runtime.insert(i, None) - if node.result: - if hasattr(node.result, 'runtime'): - self._result.interface.insert(i, node.result.interface) - self._result.inputs.insert(i, node.result.inputs) - self._result.runtime[i] = node.result.runtime - if hasattr(node.result, 'provenance'): - self._result.provenance.insert(i, node.result.provenance) + for i, nresult, err in nodes: + finalresult.runtime.insert(i, None) returncode.insert(i, err) + + if nresult: + if hasattr(nresult, 'runtime'): + 
finalresult.interface.insert(i, nresult.interface) + finalresult.inputs.insert(i, nresult.inputs) + finalresult.runtime[i] = nresult.runtime + if hasattr(nresult, 'provenance'): + finalresult.provenance.insert(i, nresult.provenance) + if self.outputs: for key, _ in list(self.outputs.items()): rm_extra = (self.config['execution'] @@ -1204,78 +1076,52 @@ def _collate_results(self, nodes): if str2bool(rm_extra) and self.needed_outputs: if key not in self.needed_outputs: continue - values = getattr(self._result.outputs, key) + values = getattr(finalresult.outputs, key) if not isdefined(values): values = [] - if node.result.outputs: - values.insert(i, node.result.outputs.get()[key]) + if nresult and nresult.outputs: + values.insert(i, nresult.outputs.get()[key]) else: values.insert(i, None) defined_vals = [isdefined(val) for val in values] - if any(defined_vals) and self._result.outputs: - setattr(self._result.outputs, key, values) + if any(defined_vals) and finalresult.outputs: + setattr(finalresult.outputs, key, values) if self.nested: for key, _ in list(self.outputs.items()): - values = getattr(self._result.outputs, key) + values = getattr(finalresult.outputs, key) if isdefined(values): - values = unflatten(values, filename_to_list(getattr(self.inputs, self.iterfield[0]))) - setattr(self._result.outputs, key, values) + values = unflatten(values, filename_to_list( + getattr(self.inputs, self.iterfield[0]))) + setattr(finalresult.outputs, key, values) if returncode and any([code is not None for code in returncode]): msg = [] for i, code in enumerate(returncode): if code is not None: msg += ['Subnode %d failed' % i] - msg += ['Error:', str(code)] + msg += ['Error: %s' % str(code)] raise Exception('Subnodes of node: %s failed:\n%s' % (self.name, '\n'.join(msg))) - def write_report(self, report_type=None, cwd=None): - if not str2bool(self.config['execution']['create_report']): - return - if report_type == 'preexec': - super(MapNode, self).write_report(report_type=report_type, cwd=cwd) - if report_type == 'postexec': - super(MapNode, self).write_report(report_type=report_type, cwd=cwd) - report_dir = op.join(cwd, '_report') - report_file = op.join(report_dir, 'report.rst') - fp = open(report_file, 'at') - fp.writelines(write_rst_header('Subnode reports', level=1)) - nitems = len(filename_to_list( - getattr(self.inputs, self.iterfield[0]))) - subnode_report_files = [] - for i in range(nitems): - nodename = '_' + self.name + str(i) - subnode_report_files.insert(i, 'subnode %d' % i + ' : ' + - op.join(cwd, - 'mapflow', - nodename, - '_report', - 'report.rst')) - fp.writelines(write_rst_list(subnode_report_files)) - fp.close() + return finalresult def get_subnodes(self): - if not self._got_inputs: - self._get_inputs() - self._got_inputs = True + """Generate subnodes of a mapnode and write pre-execution report""" + self._get_inputs() self._check_iterfield() - self.write_report(report_type='preexec', cwd=self.output_dir()) + write_report(self, report_type='preexec', is_mapnode=True) return [node for _, node in self._make_nodes()] def num_subnodes(self): - if not self._got_inputs: - self._get_inputs() - self._got_inputs = True + """Get the number of subnodes to iterate in this MapNode""" + self._get_inputs() self._check_iterfield() if self._serial: return 1 - else: - if self.nested: - return len(filename_to_list(flatten(getattr(self.inputs, self.iterfield[0])))) - else: - return len(filename_to_list(getattr(self.inputs, self.iterfield[0]))) + if self.nested: + return 
len(filename_to_list(flatten(getattr(self.inputs, self.iterfield[0])))) + return len(filename_to_list(getattr(self.inputs, self.iterfield[0]))) def _get_inputs(self): old_inputs = self._inputs.get() @@ -1310,29 +1156,37 @@ def _run_interface(self, execute=True, updatehash=False): This is primarily intended for serial execution of mapnode. A parallel execution requires creation of new nodes that can be spawned """ - old_cwd = os.getcwd() - cwd = self.output_dir() - os.chdir(cwd) self._check_iterfield() - if execute: - if self.nested: - nitems = len(filename_to_list(flatten(getattr(self.inputs, - self.iterfield[0])))) - else: - nitems = len(filename_to_list(getattr(self.inputs, - self.iterfield[0]))) - nodenames = ['_' + self.name + str(i) for i in range(nitems)] - self._collate_results(self._node_runner(self._make_nodes(cwd), - updatehash=updatehash)) - self._save_results(self._result, cwd) - # remove any node directories no longer required - dirs2remove = [] - for path in glob(op.join(cwd, 'mapflow', '*')): - if op.isdir(path): - if path.split(op.sep)[-1] not in nodenames: - dirs2remove.append(path) - for path in dirs2remove: - shutil.rmtree(path) + cwd = self.output_dir() + if not execute: + return self._load_results() + + # Set up mapnode folder names + if self.nested: + nitems = len(filename_to_list(flatten(getattr(self.inputs, + self.iterfield[0])))) else: - self._result = self._load_results(cwd) - os.chdir(old_cwd) + nitems = len(filename_to_list(getattr(self.inputs, + self.iterfield[0]))) + nnametpl = '_%s{}' % self.name + nodenames = [nnametpl.format(i) for i in range(nitems)] + + # Run mapnode + result = self._collate_results(_node_runner( + self._make_nodes(cwd), + updatehash=updatehash, + stop_first=str2bool(self.config['execution']['stop_on_first_crash']) + )) + # And store results + _save_resultfile(result, cwd, self.name) + # remove any node directories no longer required + dirs2remove = [] + for path in glob(op.join(cwd, 'mapflow', '*')): + if op.isdir(path): + if path.split(op.sep)[-1] not in nodenames: + dirs2remove.append(path) + for path in dirs2remove: + logger.debug('[MapNode] Removing folder "%s".' 
, path) + shutil.rmtree(path) + + return result diff --git a/nipype/pipeline/engine/tests/test_engine.py b/nipype/pipeline/engine/tests/test_engine.py index 8b4d559ec0..034174758a 100644 --- a/nipype/pipeline/engine/tests/test_engine.py +++ b/nipype/pipeline/engine/tests/test_engine.py @@ -488,7 +488,6 @@ def func1(in1): name='n1') n1.inputs.in1 = [[1, [2]], 3, [4, 5]] n1.run() - print(n1.get_output('out')) assert n1.get_output('out') == [[2, [3]], 4, [5, 6]] n2 = MapNode(Function(input_names=['in1'], diff --git a/nipype/pipeline/engine/utils.py b/nipype/pipeline/engine/utils.py index 96ba23cd3d..61937faac3 100644 --- a/nipype/pipeline/engine/utils.py +++ b/nipype/pipeline/engine/utils.py @@ -1,42 +1,50 @@ # -*- coding: utf-8 -*- # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: -"""Utility routines for workflow graphs -""" +"""Utility routines for workflow graphs""" from __future__ import print_function, division, unicode_literals, absolute_import -from builtins import str, open, map, next, zip, range +from builtins import str, open, next, zip, range +import os import sys -from future import standard_library -standard_library.install_aliases() +import pickle from collections import defaultdict - +import re from copy import deepcopy from glob import glob -try: - from inspect import signature -except ImportError: - from funcsigs import signature -import os -import re -import pickle +from traceback import format_exception +from hashlib import sha1 +import gzip + from functools import reduce -import numpy as np -from distutils.version import LooseVersion +import numpy as np import networkx as nx +from future import standard_library -from ...utils.filemanip import (fname_presuffix, FileNotFoundError, to_str, - filename_to_list, get_related_files) +from ... import logging, config, LooseVersion +from ...utils.filemanip import ( + relpath, makedirs, fname_presuffix, to_str, + filename_to_list, get_related_files, FileNotFoundError, + save_json, savepkl, + write_rst_header, write_rst_dict, write_rst_list, +) from ...utils.misc import str2bool from ...utils.functions import create_function_from_source -from ...interfaces.base import (CommandLine, isdefined, Undefined, - InterfaceResult) +from ...interfaces.base import ( + Bunch, CommandLine, isdefined, Undefined, + InterfaceResult, traits) from ...interfaces.utility import IdentityInterface from ...utils.provenance import ProvStore, pm, nipype_ns, get_id -from ... 
import logging, config + +try: + from inspect import signature +except ImportError: + from funcsigs import signature + +standard_library.install_aliases() logger = logging.getLogger('workflow') PY3 = sys.version_info[0] > 2 @@ -46,39 +54,265 @@ dfs_preorder = nx.dfs_preorder_nodes logger.debug('networkx 1.4 dev or higher detected') -try: - from os.path import relpath -except ImportError: - import os.path as op - - def relpath(path, start=None): - """Return a relative version of a path""" - if start is None: - start = os.curdir - if not path: - raise ValueError("no path specified") - start_list = op.abspath(start).split(op.sep) - path_list = op.abspath(path).split(op.sep) - if start_list[0].lower() != path_list[0].lower(): - unc_path, rest = op.splitunc(path) - unc_start, rest = op.splitunc(start) - if bool(unc_path) ^ bool(unc_start): - raise ValueError(("Cannot mix UNC and non-UNC paths " - "(%s and %s)") % (path, start)) - else: - raise ValueError("path is on drive %s, start on drive %s" - % (path_list[0], start_list[0])) - # Work out how much of the filepath is shared by start and path. - for i in range(min(len(start_list), len(path_list))): - if start_list[i].lower() != path_list[i].lower(): - break + +def _parameterization_dir(param): + """ + Returns the directory name for the given parameterization string as follows: + - If the parameterization is longer than 32 characters, then + return the SHA-1 hex digest. + - Otherwise, return the parameterization unchanged. + """ + if len(param) > 32: + return sha1(param.encode()).hexdigest() + return param + + +def save_hashfile(hashfile, hashed_inputs): + """Store a hashfile""" + try: + save_json(hashfile, hashed_inputs) + except (IOError, TypeError): + err_type = sys.exc_info()[0] + if err_type is TypeError: + # XXX - SG current workaround is to just + # create the hashed file and not put anything + # in it + with open(hashfile, 'wt') as fd: + fd.writelines(str(hashed_inputs)) + + logger.debug( + 'Unable to write a particular type to the json file') else: - i += 1 + logger.critical('Unable to open the file in write mode: %s', + hashfile) + + +def nodelist_runner(nodes, updatehash=False, stop_first=False): + """ + A generator that iterates and over a list of ``nodes`` and + executes them. 
+ + """ + for i, node in nodes: + err = None + result = None + try: + result = node.run(updatehash=updatehash) + except Exception: + if stop_first: + raise + + result = node.result + err = [] + if result.runtime and hasattr(result.runtime, 'traceback'): + err = [result.runtime.traceback] + + err += format_exception(*sys.exc_info()) + err = '\n'.join(err) + finally: + yield i, result, err + + +def write_report(node, report_type=None, is_mapnode=False): + """Write a report file for a node""" + if not str2bool(node.config['execution']['create_report']): + return + + if report_type not in ['preexec', 'postexec']: + logger.warning('[Node] Unknown report type "%s".', report_type) + return + + cwd = node.output_dir() + report_dir = os.path.join(cwd, '_report') + report_file = os.path.join(report_dir, 'report.rst') + makedirs(report_dir, exist_ok=True) + + logger.debug('[Node] Writing %s-exec report to "%s"', + report_type[:-4], report_file) + if report_type.startswith('pre'): + lines = [ + write_rst_header('Node: %s' % get_print_name(node), level=0), + write_rst_list(['Hierarchy : %s' % node.fullname, + 'Exec ID : %s' % node._id]), + write_rst_header('Original Inputs', level=1), + write_rst_dict(node.inputs.get()), + ] + with open(report_file, 'wt') as fp: + fp.write('\n'.join(lines)) + return - rel_list = [op.pardir] * (len(start_list) - i) + path_list[i:] - if not rel_list: - return os.curdir - return op.join(*rel_list) + lines = [ + write_rst_header('Execution Inputs', level=1), + write_rst_dict(node.inputs.get()), + ] + + result = node.result # Locally cache result + outputs = result.outputs + + if outputs is None: + with open(report_file, 'at') as fp: + fp.write('\n'.join(lines)) + return + + lines.append(write_rst_header('Execution Outputs', level=1)) + + if isinstance(outputs, Bunch): + lines.append(write_rst_dict(outputs.dictcopy())) + elif outputs: + lines.append(write_rst_dict(outputs.get())) + + if is_mapnode: + lines.append(write_rst_header('Subnode reports', level=1)) + nitems = len(filename_to_list( + getattr(node.inputs, node.iterfield[0]))) + subnode_report_files = [] + for i in range(nitems): + nodecwd = os.path.join( + cwd, 'mapflow', '_%s%d' % (node.name, i), + '_report', 'report.rst') + subnode_report_files.append( + 'subnode %d : %s' % (i, nodecwd)) + + lines.append(write_rst_list(subnode_report_files)) + + with open(report_file, 'at') as fp: + fp.write('\n'.join(lines)) + return + + lines.append(write_rst_header('Runtime info', level=1)) + # Init rst dictionary of runtime stats + rst_dict = { + 'hostname': result.runtime.hostname, + 'duration': result.runtime.duration, + } + + if hasattr(result.runtime, 'cmdline'): + rst_dict['command'] = result.runtime.cmdline + + # Try and insert memory/threads usage if available + if hasattr(result.runtime, 'mem_peak_gb'): + rst_dict['mem_peak_gb'] = result.runtime.mem_peak_gb + + if hasattr(result.runtime, 'cpu_percent'): + rst_dict['cpu_percent'] = result.runtime.cpu_percent + + lines.append(write_rst_dict(rst_dict)) + + # Collect terminal output + if hasattr(result.runtime, 'merged'): + lines += [ + write_rst_header('Terminal output', level=2), + write_rst_list(result.runtime.merged), + ] + if hasattr(result.runtime, 'stdout'): + lines += [ + write_rst_header('Terminal - standard output', level=2), + write_rst_list(result.runtime.stdout), + ] + if hasattr(result.runtime, 'stderr'): + lines += [ + write_rst_header('Terminal - standard error', level=2), + write_rst_list(result.runtime.stderr), + ] + + # Store environment + if 
hasattr(result.runtime, 'environ'): + lines += [ + write_rst_header('Environment', level=2), + write_rst_dict(result.runtime.environ), + ] + + with open(report_file, 'at') as fp: + fp.write('\n'.join(lines)) + return + + +def save_resultfile(result, cwd, name): + """Save a result pklz file to ``cwd``""" + resultsfile = os.path.join(cwd, 'result_%s.pklz' % name) + if result.outputs: + try: + outputs = result.outputs.get() + except TypeError: + outputs = result.outputs.dictcopy() # outputs was a bunch + result.outputs.set(**modify_paths( + outputs, relative=True, basedir=cwd)) + + savepkl(resultsfile, result) + logger.debug('saved results in %s', resultsfile) + + if result.outputs: + result.outputs.set(**outputs) + + +def load_resultfile(path, name): + """ + Load InterfaceResult file from path + + Parameter + --------- + + path : base_dir of node + name : name of node + + Returns + ------- + + result : InterfaceResult structure + aggregate : boolean indicating whether node should aggregate_outputs + attribute error : boolean indicating whether there was some mismatch in + versions of traits used to store result and hence node needs to + rerun + """ + aggregate = True + resultsoutputfile = os.path.join(path, 'result_%s.pklz' % name) + result = None + attribute_error = False + if os.path.exists(resultsoutputfile): + pkl_file = gzip.open(resultsoutputfile, 'rb') + try: + result = pickle.load(pkl_file) + except UnicodeDecodeError: + # Was this pickle created with Python 2.x? + pickle.load(pkl_file, fix_imports=True, encoding='utf-8') + logger.warning('Successfully loaded pickle in compatibility mode') + except (traits.TraitError, AttributeError, ImportError, + EOFError) as err: + if isinstance(err, (AttributeError, ImportError)): + attribute_error = True + logger.debug('attribute error: %s probably using ' + 'different trait pickled file', str(err)) + else: + logger.debug( + 'some file does not exist. 
hence trait cannot be set')
+    else:
+        if result.outputs:
+            try:
+                outputs = result.outputs.get()
+            except TypeError:
+                outputs = result.outputs.dictcopy()  # outputs == Bunch
+            try:
+                result.outputs.set(**modify_paths(outputs,
+                                                  relative=False,
+                                                  basedir=path))
+            except FileNotFoundError:
+                logger.debug('conversion to full path results in '
+                             'non existent file')
+        aggregate = False
+    pkl_file.close()
+    logger.debug('Aggregate: %s', aggregate)
+    return result, aggregate, attribute_error
+
+
+def strip_temp(files, wd):
+    """Remove temp from a list of file paths"""
+    out = []
+    for f in files:
+        if isinstance(f, list):
+            out.append(strip_temp(f, wd))
+        else:
+            out.append(f.replace(os.path.join(wd, '_tempinput'), wd))
+    return out
 
 
 def _write_inputs(node):
@@ -87,10 +321,10 @@ def _write_inputs(node):
     for key, _ in list(node.inputs.items()):
         val = getattr(node.inputs, key)
         if isdefined(val):
-            if type(val) == str:
+            if isinstance(val, (str, bytes)):
                 try:
                     func = create_function_from_source(val)
-                except RuntimeError as e:
+                except RuntimeError:
                     lines.append("%s.inputs.%s = '%s'" % (nodename, key, val))
                 else:
                     funcname = [name for name in func.__globals__
@@ -115,18 +349,18 @@ def format_node(node, format='python', include_config=False):
     lines = []
     name = node.fullname.replace('.', '_')
     if format == 'python':
-        klass = node._interface
+        klass = node.interface
         importline = 'from %s import %s' % (klass.__module__,
                                             klass.__class__.__name__)
         comment = '# Node: %s' % node.fullname
-        spec = signature(node._interface.__init__)
+        spec = signature(node.interface.__init__)
         args = [p.name for p in list(spec.parameters.values())]
         args = args[1:]
         if args:
             filled_args = []
             for arg in args:
-                if hasattr(node._interface, '_%s' % arg):
-                    filled_args.append('%s=%s' % (arg, getattr(node._interface,
+                if hasattr(node.interface, '_%s' % arg):
+                    filled_args.append('%s=%s' % (arg, getattr(node.interface,
                                                                '_%s' % arg)))
             args = ', '.join(filled_args)
         else:
@@ -195,7 +429,7 @@ def modify_paths(object, relative=True, basedir=None):
             else:
                 out = os.path.abspath(os.path.join(basedir, object))
                 if not os.path.exists(out):
-                    raise FileNotFoundError('File %s not found' % out)
+                    raise IOError('File %s not found' % out)
         else:
             out = object
     return out
@@ -210,8 +444,8 @@ def get_print_name(node, simple_form=True):
     """
     name = node.fullname
     if hasattr(node, '_interface'):
-        pkglist = node._interface.__class__.__module__.split('.')
-        interface = node._interface.__class__.__name__
+        pkglist = node.interface.__class__.__module__.split('.')
+        interface = node.interface.__class__.__name__
         destclass = ''
         if len(pkglist) > 2:
             destclass = '.%s' % pkglist[2]
@@ -247,23 +481,22 @@ def _create_dot_graph(graph, show_connectinfo=False, simple_form=True):
 
 
 def _write_detailed_dot(graph, dotfilename):
-    """Create a dot file with connection info
-
-    digraph structs {
-    node [shape=record];
-    struct1 [label="<f0> left|<f1> mid\ dle|<f2> right"];
-    struct2 [label="<f0> one|<f1> two"];
-    struct3 [label="hello\nworld |{ b |{c|<here> d|e}| f}| g | h"];
-    struct1:f1 -> struct2:f0;
-    struct1:f0 -> struct2:f1;
-    struct1:f2 -> struct3:here;
-    }
+    r"""
+    Create a dot file with connection info ::
+
+        digraph structs {
+        node [shape=record];
+        struct1 [label="<f0> left|<f1> middle|<f2> right"];
+        struct2 [label="<f0> one|<f1> two"];
+        struct3 [label="hello\nworld |{ b |{c|<here> d|e}| f}| g | h"];
+        struct1:f1 -> struct2:f0;
+        struct1:f0 -> struct2:f1;
+        struct1:f2 -> struct3:here;
+        }
     """
     text = ['digraph structs {', 'node [shape=record];']
     # write nodes
     edges = []
-    replacefunk = lambda x: x.replace('_', 
'').replace('.', ''). \
-        replace('@', '').replace('-', '')
     for n in nx.topological_sort(graph):
         nodename = str(n)
         inports = []
@@ -274,18 +507,16 @@ def _write_detailed_dot(graph, dotfilename):
             else:
                 outport = cd[0][0]
             inport = cd[1]
-            ipstrip = 'in' + replacefunk(inport)
-            opstrip = 'out' + replacefunk(outport)
+            ipstrip = 'in%s' % _replacefunk(inport)
+            opstrip = 'out%s' % _replacefunk(outport)
             edges.append('%s:%s:e -> %s:%s:w;' % (str(u).replace('.', ''),
                                                   opstrip, str(v).replace('.', ''),
                                                   ipstrip))
             if inport not in inports:
                 inports.append(inport)
-        inputstr = '{IN'
-        for ip in sorted(inports):
-            inputstr += '|<in%s> %s' % (replacefunk(ip), ip)
-        inputstr += '}'
+        inputstr = ['{IN'] + ['|<in%s> %s' % (_replacefunk(ip), ip)
+                              for ip in sorted(inports)] + ['}']
         outports = []
         for u, v, d in graph.out_edges(nbunch=n, data=True):
             for cd in d['connect']:
@@ -295,13 +526,11 @@ def _write_detailed_dot(graph, dotfilename):
                     outport = cd[0][0]
                 if outport not in outports:
                     outports.append(outport)
-        outputstr = '{OUT'
-        for op in sorted(outports):
-            outputstr += '|<out%s> %s' % (replacefunk(op), op)
-        outputstr += '}'
+        outputstr = ['{OUT'] + ['|<out%s> %s' % (_replacefunk(oport), oport)
+                                for oport in sorted(outports)] + ['}']
         srcpackage = ''
         if hasattr(n, '_interface'):
-            pkglist = n._interface.__class__.__module__.split('.')
+            pkglist = n.interface.__class__.__module__.split('.')
             if len(pkglist) > 2:
                 srcpackage = pkglist[2]
         srchierarchy = '.'.join(nodename.split('.')[1:-1])
@@ -309,19 +538,23 @@
                                             srcpackage, srchierarchy)
         text += ['%s [label="%s|%s|%s"];' % (nodename.replace('.', ''),
-                    inputstr,
+                    ''.join(inputstr),
                     nodenamestr,
-                    outputstr)]
+                    ''.join(outputstr))]
     # write edges
     for edge in sorted(edges):
         text.append(edge)
     text.append('}')
-    filep = open(dotfilename, 'wt')
-    filep.write('\n'.join(text))
-    filep.close()
+    with open(dotfilename, 'wt') as filep:
+        filep.write('\n'.join(text))
     return text
 
 
+def _replacefunk(x):
+    return x.replace('_', '').replace(
+        '.', '').replace('@', '').replace('-', '')
+
+
 # Graph manipulations for iterable expansion
 def _get_valid_pathstr(pathstr):
     """Remove disallowed characters from path
@@ -340,8 +573,7 @@ def _get_valid_pathstr(pathstr):
 def expand_iterables(iterables, synchronize=False):
     if synchronize:
         return synchronize_iterables(iterables)
-    else:
-        return list(walk(list(iterables.items())))
+    return list(walk(list(iterables.items())))
 
 
 def count_iterables(iterables, synchronize=False):
@@ -352,10 +584,7 @@ def count_iterables(iterables, synchronize=False):
     Otherwise, the count is the product of the iterables value
     list sizes.
     """
-    if synchronize:
-        op = max
-    else:
-        op = lambda x, y: x * y
+    op = max if synchronize else lambda x, y: x * y
    return reduce(op, [len(func()) for _, func in list(iterables.items())])
 
 
@@ -524,14 +753,13 @@ def _merge_graphs(supergraph, nodes, subgraph, nodeid, iterables,
     logger.debug('Parameterization: paramstr=%s', paramstr)
     levels = get_levels(Gc)
     for n in Gc.nodes():
-        """
-        update parameterization of the node to reflect the location of
-        the output directory. For example, if the iterables along a
-        path of the directed graph consisted of the variables 'a' and
-        'b', then every node in the path including and after the node
-        with iterable 'b' will be placed in a directory
-        _a_aval/_b_bval/.
-        """
+        # update parameterization of the node to reflect the location of
+        # the output directory. 
For example, if the iterables along a + # path of the directed graph consisted of the variables 'a' and + # 'b', then every node in the path including and after the node + # with iterable 'b' will be placed in a directory + # _a_aval/_b_bval/. + path_length = levels[n] # enter as negative numbers so that earlier iterables with longer # path lengths get precedence in a sort @@ -583,7 +811,7 @@ def _identity_nodes(graph, include_iterables): to True. """ return [node for node in nx.topological_sort(graph) - if isinstance(node._interface, IdentityInterface) and + if isinstance(node.interface, IdentityInterface) and (include_iterables or getattr(node, 'iterables') is None)] @@ -598,7 +826,7 @@ def _remove_identity_node(graph, node): else: _propagate_root_output(graph, node, field, connections) graph.remove_nodes_from([node]) - logger.debug("Removed the identity node %s from the graph." % node) + logger.debug("Removed the identity node %s from the graph.", node) def _node_ports(graph, node): @@ -689,12 +917,12 @@ def generate_expanded_graph(graph_in): # the iterable nodes inodes = _iterable_nodes(graph_in) - logger.debug("Detected iterable nodes %s" % inodes) + logger.debug("Detected iterable nodes %s", inodes) # while there is an iterable node, expand the iterable node's # subgraphs while inodes: inode = inodes[0] - logger.debug("Expanding the iterable node %s..." % inode) + logger.debug("Expanding the iterable node %s...", inode) # the join successor nodes of the current iterable node jnodes = [node for node in graph_in.nodes() @@ -715,8 +943,8 @@ def generate_expanded_graph(graph_in): for src, dest in edges2remove: graph_in.remove_edge(src, dest) - logger.debug("Excised the %s -> %s join node in-edge." - % (src, dest)) + logger.debug("Excised the %s -> %s join node in-edge.", + src, dest) if inode.itersource: # the itersource is a (node name, fields) tuple @@ -733,8 +961,8 @@ def generate_expanded_graph(graph_in): raise ValueError("The node %s itersource %s was not found" " among the iterable predecessor nodes" % (inode, src_name)) - logger.debug("The node %s has iterable source node %s" - % (inode, iter_src)) + logger.debug("The node %s has iterable source node %s", + inode, iter_src) # look up the iterables for this particular itersource descendant # using the iterable source ancestor values as a key iterables = {} @@ -760,7 +988,7 @@ def make_field_func(*pair): else: iterables = inode.iterables.copy() inode.iterables = None - logger.debug('node: %s iterables: %s' % (inode, iterables)) + logger.debug('node: %s iterables: %s', inode, iterables) # collect the subnodes to expand subnodes = [s for s in dfs_preorder(graph_in, inode)] @@ -768,7 +996,7 @@ def make_field_func(*pair): for s in subnodes: prior_prefix.extend(re.findall('\.(.)I', s._id)) prior_prefix = sorted(prior_prefix) - if not len(prior_prefix): + if not prior_prefix: iterable_prefix = 'a' else: if prior_prefix[-1] == 'z': @@ -798,12 +1026,12 @@ def make_field_func(*pair): # the edge source node replicates expansions = defaultdict(list) for node in graph_in.nodes(): - for src_id, edge_data in list(old_edge_dict.items()): + for src_id in list(old_edge_dict.keys()): if node.itername.startswith(src_id): expansions[src_id].append(node) for in_id, in_nodes in list(expansions.items()): logger.debug("The join node %s input %s was expanded" - " to %d nodes." 
% (jnode, in_id, len(in_nodes))) + " to %d nodes.", jnode, in_id, len(in_nodes)) # preserve the node iteration order by sorting on the node id for in_nodes in list(expansions.values()): in_nodes.sort(key=lambda node: node._id) @@ -843,12 +1071,12 @@ def make_field_func(*pair): if dest_field in slots: slot_field = slots[dest_field] connects[con_idx] = (src_field, slot_field) - logger.debug("Qualified the %s -> %s join field" - " %s as %s." % - (in_node, jnode, dest_field, slot_field)) + logger.debug( + "Qualified the %s -> %s join field %s as %s.", + in_node, jnode, dest_field, slot_field) graph_in.add_edge(in_node, jnode, **newdata) logger.debug("Connected the join node %s subgraph to the" - " expanded join point %s" % (jnode, in_node)) + " expanded join point %s", jnode, in_node) # nx.write_dot(graph_in, '%s_post.dot' % node) # the remaining iterable nodes @@ -904,7 +1132,7 @@ def _standardize_iterables(node): fields = set(node.inputs.copyable_trait_names()) # Flag indicating whether the iterables are in the alternate # synchronize form and are not converted to a standard format. - synchronize = False + # synchronize = False # OE: commented out since it is not used # A synchronize iterables node without an itersource can be in # [fields, value tuples] format rather than # [(field, value list), (field, value list), ...] @@ -979,9 +1207,9 @@ def _transpose_iterables(fields, values): if val is not None: transposed[fields[idx]][key].append(val) return list(transposed.items()) - else: - return list(zip(fields, [[v for v in list(transpose) if v is not None] - for transpose in zip(*values)])) + + return list(zip(fields, [[v for v in list(transpose) if v is not None] + for transpose in zip(*values)])) def export_graph(graph_in, base_dir=None, show=False, use_execgraph=False, @@ -1015,8 +1243,8 @@ def export_graph(graph_in, base_dir=None, show=False, use_execgraph=False, logger.debug('using input graph') if base_dir is None: base_dir = os.getcwd() - if not os.path.exists(base_dir): - os.makedirs(base_dir) + + makedirs(base_dir, exist_ok=True) outfname = fname_presuffix(dotfilename, suffix='_detailed.dot', use_ext=False, @@ -1027,7 +1255,7 @@ def export_graph(graph_in, base_dir=None, show=False, use_execgraph=False, res = CommandLine(cmd, terminal_output='allatonce', resource_monitor=False).run() if res.runtime.returncode: - logger.warn('dot2png: %s', res.runtime.stderr) + logger.warning('dot2png: %s', res.runtime.stderr) pklgraph = _create_dot_graph(graph, show_connectinfo, simple_form) simplefname = fname_presuffix(dotfilename, suffix='.dot', @@ -1039,7 +1267,7 @@ def export_graph(graph_in, base_dir=None, show=False, use_execgraph=False, res = CommandLine(cmd, terminal_output='allatonce', resource_monitor=False).run() if res.runtime.returncode: - logger.warn('dot2png: %s', res.runtime.stderr) + logger.warning('dot2png: %s', res.runtime.stderr) if show: pos = nx.graphviz_layout(pklgraph, prog='dot') nx.draw(pklgraph, pos) @@ -1067,26 +1295,6 @@ def format_dot(dotfilename, format='png'): return dotfilename -def make_output_dir(outdir): - """Make the output_dir if it doesn't exist. 
- - Parameters - ---------- - outdir : output directory to create - - """ - # this odd approach deals with concurrent directory cureation - try: - if not os.path.exists(os.path.abspath(outdir)): - logger.debug("Creating %s", outdir) - os.makedirs(outdir) - except OSError: - logger.debug("Problem creating %s", outdir) - if not os.path.exists(outdir): - raise OSError('Could not create %s' % outdir) - return outdir - - def get_all_files(infile): files = [infile] if infile.endswith(".img"): @@ -1102,7 +1310,7 @@ def walk_outputs(object): """ out = [] if isinstance(object, dict): - for key, val in sorted(object.items()): + for _, val in sorted(object.items()): if isdefined(val): out.extend(walk_outputs(val)) elif isinstance(object, (list, tuple)): @@ -1159,13 +1367,13 @@ def clean_working_directory(outputs, cwd, inputs, needed_outputs, config, for filename in needed_files: temp.extend(get_related_files(filename)) needed_files = temp - logger.debug('Needed files: %s' % (';'.join(needed_files))) - logger.debug('Needed dirs: %s' % (';'.join(needed_dirs))) + logger.debug('Needed files: %s', ';'.join(needed_files)) + logger.debug('Needed dirs: %s', ';'.join(needed_dirs)) files2remove = [] if str2bool(config['execution']['remove_unnecessary_outputs']): for f in walk_files(cwd): if f not in needed_files: - if len(needed_dirs) == 0: + if not needed_dirs: files2remove.append(f) elif not any([f.startswith(dname) for dname in needed_dirs]): files2remove.append(f) @@ -1178,7 +1386,7 @@ def clean_working_directory(outputs, cwd, inputs, needed_outputs, config, for f in walk_files(cwd): if f in input_files and f not in needed_files: files2remove.append(f) - logger.debug('Removing files: %s' % (';'.join(files2remove))) + logger.debug('Removing files: %s', ';'.join(files2remove)) for f in files2remove: os.remove(f) for key in outputs.copyable_trait_names(): @@ -1242,9 +1450,9 @@ def write_workflow_prov(graph, filename=None, format='all'): processes = [] nodes = graph.nodes() - for idx, node in enumerate(nodes): + for node in nodes: result = node.result - classname = node._interface.__class__.__name__ + classname = node.interface.__class__.__name__ _, hashval, _, _ = node.hash_exists() attrs = {pm.PROV["type"]: nipype_ns[classname], pm.PROV["label"]: '_'.join((classname, node.name)), @@ -1260,7 +1468,7 @@ def write_workflow_prov(graph, filename=None, format='all'): if idx < len(result.inputs): subresult.inputs = result.inputs[idx] if result.outputs: - for key, value in list(result.outputs.items()): + for key, _ in list(result.outputs.items()): values = getattr(result.outputs, key) if isdefined(values) and idx < len(values): subresult.outputs[key] = values[idx] @@ -1334,9 +1542,9 @@ def write_workflow_resources(graph, filename=None, append=None): with open(filename, 'r' if PY3 else 'rb') as rsf: big_dict = json.load(rsf) - for idx, node in enumerate(graph.nodes()): + for _, node in enumerate(graph.nodes()): nodename = node.fullname - classname = node._interface.__class__.__name__ + classname = node.interface.__class__.__name__ params = '' if node.parameterization: diff --git a/nipype/pipeline/engine/workflows.py b/nipype/pipeline/engine/workflows.py index d58424fcd5..e00f105c5e 100644 --- a/nipype/pipeline/engine/workflows.py +++ b/nipype/pipeline/engine/workflows.py @@ -15,50 +15,38 @@ """ from __future__ import print_function, division, unicode_literals, absolute_import -from builtins import range, object, str, bytes, open - -# Py2 compat: 
http://python-future.org/compatible_idioms.html#collections-counter-and-ordereddict -from future import standard_library -standard_library.install_aliases() +from builtins import str, bytes, open +import os +import os.path as op +import sys from datetime import datetime - from copy import deepcopy import pickle -import os -import os.path as op import shutil -import sys -from warnings import warn import numpy as np import networkx as nx - from ... import config, logging - -from ...utils.misc import (unflatten, str2bool) +from ...utils.misc import str2bool from ...utils.functions import (getsource, create_function_from_source) -from ...interfaces.base import (traits, InputMultiPath, CommandLine, - Undefined, TraitedSpec, DynamicTraitedSpec, - Bunch, InterfaceResult, Interface, - TraitDictObject, TraitListObject, isdefined) - -from ...utils.filemanip import (save_json, FileNotFoundError, md5, - filename_to_list, list_to_filename, - copyfiles, fnames_presuffix, loadpkl, - split_filename, load_json, savepkl, - write_rst_header, write_rst_dict, - write_rst_list, to_str) -from .utils import (generate_expanded_graph, modify_paths, - export_graph, make_output_dir, write_workflow_prov, - write_workflow_resources, - clean_working_directory, format_dot, topological_sort, - get_print_name, merge_dict, evaluate_connect_function, - _write_inputs, format_node) + +from ...interfaces.base import ( + traits, TraitedSpec, TraitDictObject, TraitListObject) +from ...utils.filemanip import save_json, makedirs, to_str +from .utils import ( + generate_expanded_graph, export_graph, write_workflow_prov, + write_workflow_resources, format_dot, topological_sort, + get_print_name, merge_dict, format_node +) from .base import EngineBase -from .nodes import Node, MapNode +from .nodes import MapNode + +# Py2 compat: http://python-future.org/compatible_idioms.html#collections-counter-and-ordereddict +from future import standard_library +standard_library.install_aliases() logger = logging.getLogger('workflow') @@ -202,16 +190,16 @@ def connect(self, *args, **kwargs): connected. 
""" % (srcnode, source, destnode, dest, dest, destnode)) if not (hasattr(destnode, '_interface') and - ('.io' in str(destnode._interface.__class__) or - any(['.io' in str(val) for val in - destnode._interface.__class__.__bases__])) + ('.io' in str(destnode._interface.__class__) or + any(['.io' in str(val) for val in + destnode._interface.__class__.__bases__])) ): if not destnode._check_inputs(dest): not_found.append(['in', destnode.name, dest]) if not (hasattr(srcnode, '_interface') and - ('.io' in str(srcnode._interface.__class__) - or any(['.io' in str(val) for val in - srcnode._interface.__class__.__bases__]))): + ('.io' in str(srcnode._interface.__class__) or + any(['.io' in str(val) + for val in srcnode._interface.__class__.__bases__]))): if isinstance(source, tuple): # handles the case that source is specified # with a function @@ -425,7 +413,7 @@ def write_graph(self, dotfilename='graph.dot', graph2use='hierarchical', base_dir = op.join(base_dir, self.name) else: base_dir = os.getcwd() - base_dir = make_output_dir(base_dir) + base_dir = makedirs(base_dir, exist_ok=True) if graph2use in ['hierarchical', 'colored']: if self.name[:1].isdigit(): # these graphs break if int raise ValueError('{} graph failed, workflow name cannot begin ' @@ -571,12 +559,6 @@ def run(self, plugin=None, plugin_args=None, updatehash=False): runner = plugin_mod(plugin_args=plugin_args) flatgraph = self._create_flat_graph() self.config = merge_dict(deepcopy(config._sections), self.config) - if 'crashdump_dir' in self.config: - warn(("Deprecated: workflow.config['crashdump_dir']\n" - "Please use config['execution']['crashdump_dir']")) - crash_dir = self.config['crashdump_dir'] - self.config['execution']['crashdump_dir'] = crash_dir - del self.config['crashdump_dir'] logger.info('Workflow %s settings: %s', self.name, to_str(sorted(self.config))) self._set_needed_outputs(flatgraph) execgraph = generate_expanded_graph(deepcopy(flatgraph)) @@ -611,8 +593,7 @@ def _write_report_info(self, workingdir, name, graph): if workingdir is None: workingdir = os.getcwd() report_dir = op.join(workingdir, name) - if not op.exists(report_dir): - os.makedirs(report_dir) + makedirs(report_dir, exist_ok=True) shutil.copyfile(op.join(op.dirname(__file__), 'report_template.html'), op.join(report_dir, 'index.html')) @@ -930,13 +911,13 @@ def _get_dot(self, prefix=None, hierarchy=None, colored=False, prefix = ' ' if hierarchy is None: hierarchy = [] - colorset = ['#FFFFC8', # Y - '#0000FF', '#B4B4FF', '#E6E6FF', # B - '#FF0000', '#FFB4B4', '#FFE6E6', # R - '#00A300', '#B4FFB4', '#E6FFE6', # G - '#0000FF', '#B4B4FF'] # loop B + colorset = ['#FFFFC8', # Y + '#0000FF', '#B4B4FF', '#E6E6FF', # B + '#FF0000', '#FFB4B4', '#FFE6E6', # R + '#00A300', '#B4FFB4', '#E6FFE6', # G + '#0000FF', '#B4B4FF'] # loop B if level > len(colorset) - 2: - level = 3 # Loop back to blue + level = 3 # Loop back to blue dotlist = ['%slabel="%s";' % (prefix, self.name)] for node in nx.topological_sort(self._graph): diff --git a/nipype/pipeline/plugins/base.py b/nipype/pipeline/plugins/base.py index e27733ab77..ec8c68a148 100644 --- a/nipype/pipeline/plugins/base.py +++ b/nipype/pipeline/plugins/base.py @@ -408,7 +408,7 @@ def _remove_node_dirs(self): continue if self.proc_done[idx] and (not self.proc_pending[idx]): self.refidx[idx, idx] = -1 - outdir = self.procs[idx]._output_directory() + outdir = self.procs[idx].output_dir() logger.info(('[node dependencies finished] ' 'removing node: %s from directory %s') % (self.procs[idx]._id, outdir)) diff --git 
a/nipype/pipeline/plugins/tools.py b/nipype/pipeline/plugins/tools.py index 499a1db2d7..c07a8966b6 100644 --- a/nipype/pipeline/plugins/tools.py +++ b/nipype/pipeline/plugins/tools.py @@ -15,7 +15,7 @@ from traceback import format_exception from ... import logging -from ...utils.filemanip import savepkl, crash2txt +from ...utils.filemanip import savepkl, crash2txt, makedirs logger = logging.getLogger('workflow') @@ -42,8 +42,7 @@ def report_crash(node, traceback=None, hostname=None): timeofcrash, login_name, name, str(uuid.uuid4())) crashdir = node.config['execution'].get('crashdump_dir', os.getcwd()) - if not os.path.exists(crashdir): - os.makedirs(crashdir) + makedirs(crashdir, exist_ok=True) crashfile = os.path.join(crashdir, crashfile) if node.config['execution']['crashfile_format'].lower() in ['text', 'txt']: diff --git a/nipype/utils/config.py b/nipype/utils/config.py index 3c9218f2a6..c02be71f64 100644 --- a/nipype/utils/config.py +++ b/nipype/utils/config.py @@ -15,7 +15,6 @@ import errno import atexit from warnings import warn -from io import StringIO from distutils.version import LooseVersion import configparser import numpy as np @@ -37,21 +36,19 @@ NUMPY_MMAP = LooseVersion(np.__version__) >= LooseVersion('1.12.0') -# Get home directory in platform-agnostic way -homedir = os.path.expanduser('~') -default_cfg = """ +DEFAULT_CONFIG_TPL = """\ [logging] workflow_level = INFO utils_level = INFO interface_level = INFO log_to_file = false -log_directory = %s +log_directory = {log_dir} log_size = 16384000 log_rotate = 4 [execution] create_report = true -crashdump_dir = %s +crashdump_dir = {crashdump_dir} hash_method = timestamp job_finished_timeout = 5 keep_inputs = false @@ -79,7 +76,7 @@ [check] interval = 1209600 -""" % (homedir, os.getcwd()) +""".format def mkdir_p(path): @@ -97,15 +94,17 @@ class NipypeConfig(object): def __init__(self, *args, **kwargs): self._config = configparser.ConfigParser() + self._cwd = None + config_dir = os.path.expanduser('~/.nipype') - config_file = os.path.join(config_dir, 'nipype.cfg') self.data_file = os.path.join(config_dir, 'nipype.json') - self._config.readfp(StringIO(default_cfg)) + + self.set_default_config() self._display = None self._resource_monitor = None if os.path.exists(config_dir): - self._config.read([config_file, 'nipype.cfg']) + self._config.read([os.path.join(config_dir, 'nipype.cfg'), 'nipype.cfg']) for option in CONFIG_DEPRECATIONS: for section in ['execution', 'logging', 'monitoring']: @@ -115,8 +114,32 @@ def __init__(self, *args, **kwargs): # Warn implicit in get self.set(new_section, new_option, self.get(section, option)) + @property + def cwd(self): + """Cache current working directory ASAP""" + # Run getcwd only once, preventing multiproc to finish + # with error having changed to the wrong path + if self._cwd is None: + try: + self._cwd = os.getcwd() + except OSError: + warn('Trying to run Nipype from a nonexistent directory "%s".', + os.getenv('PWD', 'unknown')) + raise + return self._cwd + def set_default_config(self): - self._config.readfp(StringIO(default_cfg)) + """Read default settings template and set into config object""" + default_cfg = DEFAULT_CONFIG_TPL( + log_dir=os.path.expanduser('~'), # Get $HOME in a platform-agnostic way + crashdump_dir=self.cwd # Read cached cwd + ) + + try: + self._config.read_string(default_cfg) # Python >= 3.2 + except AttributeError: + from io import StringIO + self._config.readfp(StringIO(default_cfg)) def enable_debug_mode(self): """Enables debug configuration""" diff --git 
a/nipype/utils/filemanip.py b/nipype/utils/filemanip.py index d721b740a9..d87f498d00 100644 --- a/nipype/utils/filemanip.py +++ b/nipype/utils/filemanip.py @@ -8,12 +8,14 @@ import sys import pickle +import errno import subprocess as sp import gzip import hashlib import locale from hashlib import md5 import os +import os.path as op import re import shutil import posixpath @@ -75,8 +77,8 @@ def split_filename(fname): special_extensions = [".nii.gz", ".tar.gz"] - pth = os.path.dirname(fname) - fname = os.path.basename(fname) + pth = op.dirname(fname) + fname = op.basename(fname) ext = None for special_ext in special_extensions: @@ -87,7 +89,7 @@ def split_filename(fname): fname = fname[:-ext_len] break if not ext: - fname, ext = os.path.splitext(fname) + fname, ext = op.splitext(fname) return pth, fname, ext @@ -186,8 +188,8 @@ def fname_presuffix(fname, prefix='', suffix='', newpath=None, use_ext=True): # No need for isdefined: bool(Undefined) evaluates to False if newpath: - pth = os.path.abspath(newpath) - return os.path.join(pth, prefix + fname + suffix + ext) + pth = op.abspath(newpath) + return op.join(pth, prefix + fname + suffix + ext) def fnames_presuffix(fnames, prefix='', suffix='', newpath=None, use_ext=True): @@ -205,14 +207,14 @@ def hash_rename(filename, hashvalue): """ path, name, ext = split_filename(filename) newfilename = ''.join((name, '_0x', hashvalue, ext)) - return os.path.join(path, newfilename) + return op.join(path, newfilename) def check_forhash(filename): """checks if file has a hash in its filename""" if isinstance(filename, list): filename = filename[0] - path, name = os.path.split(filename) + path, name = op.split(filename) if re.search('(_0x[a-z0-9]{32})', name): hashvalue = re.findall('(_0x[a-z0-9]{32})', name) return True, hashvalue @@ -223,7 +225,7 @@ def check_forhash(filename): def hash_infile(afile, chunk_len=8192, crypto=hashlib.md5): """ Computes hash of a file using 'crypto' module""" hex = None - if os.path.isfile(afile): + if op.isfile(afile): crypto_obj = crypto() with open(afile, 'rb') as fp: while True: @@ -238,7 +240,7 @@ def hash_infile(afile, chunk_len=8192, crypto=hashlib.md5): def hash_timestamp(afile): """ Computes md5 hash of the timestamp of a file """ md5hex = None - if os.path.isfile(afile): + if op.isfile(afile): md5obj = md5() stat = os.stat(afile) md5obj.update(str(stat.st_size).encode()) @@ -332,7 +334,7 @@ def copyfile(originalfile, newfile, copy=False, create_new=False, fmlogger.debug(newfile) if create_new: - while os.path.exists(newfile): + while op.exists(newfile): base, fname, ext = split_filename(newfile) s = re.search('_c[0-9]{4,4}$', fname) i = 0 @@ -362,9 +364,9 @@ def copyfile(originalfile, newfile, copy=False, create_new=False, # copy of file (same hash) (keep) # different file (diff hash) (unlink) keep = False - if os.path.lexists(newfile): - if os.path.islink(newfile): - if all((os.readlink(newfile) == os.path.realpath(originalfile), + if op.lexists(newfile): + if op.islink(newfile): + if all((os.readlink(newfile) == op.realpath(originalfile), not use_hardlink, not copy)): keep = True elif posixpath.samefile(newfile, originalfile): @@ -394,7 +396,7 @@ def copyfile(originalfile, newfile, copy=False, create_new=False, try: fmlogger.debug('Linking File: %s->%s', newfile, originalfile) # Use realpath to avoid hardlinking symlinks - os.link(os.path.realpath(originalfile), newfile) + os.link(op.realpath(originalfile), newfile) except OSError: use_hardlink = False # Disable hardlink for associated files else: @@ -421,7 +423,7 
@@ def copyfile(originalfile, newfile, copy=False, create_new=False,
         related_file_pairs = (get_related_files(f, include_this_file=False)
                               for f in (originalfile, newfile))
         for alt_ofile, alt_nfile in zip(*related_file_pairs):
-            if os.path.exists(alt_ofile):
+            if op.exists(alt_ofile):
                 copyfile(alt_ofile, alt_nfile, copy, hashmethod=hashmethod,
                          use_hardlink=use_hardlink,
                          copy_related_files=False)
@@ -446,7 +448,7 @@ def get_related_files(filename, include_this_file=True):
     if this_type in type_set:
         for related_type in type_set:
             if include_this_file or related_type != this_type:
-                related_files.append(os.path.join(path, name + related_type))
+                related_files.append(op.join(path, name + related_type))
     if not len(related_files):
         related_files = [filename]
     return related_files
@@ -518,9 +520,9 @@ def check_depends(targets, dependencies):
     """
     tgts = filename_to_list(targets)
     deps = filename_to_list(dependencies)
-    return all(map(os.path.exists, tgts)) and \
-        min(map(os.path.getmtime, tgts)) > \
-        max(list(map(os.path.getmtime, deps)) + [0])
+    return all(map(op.exists, tgts)) and \
+        min(map(op.getmtime, tgts)) > \
+        max(list(map(op.getmtime, deps)) + [0])
 
 
 def save_json(filename, data):
@@ -667,12 +669,72 @@ def dist_is_editable(dist):
     # Borrowed from `pip`'s' API
     """
     for path_item in sys.path:
-        egg_link = os.path.join(path_item, dist + '.egg-link')
-        if os.path.isfile(egg_link):
+        egg_link = op.join(path_item, dist + '.egg-link')
+        if op.isfile(egg_link):
             return True
     return False
 
 
+def makedirs(path, exist_ok=False):
+    """
+    Create path, if it doesn't exist.
+
+    Parameters
+    ----------
+    path : output directory to create
+
+    """
+    if not exist_ok:  # The old makedirs
+        os.makedirs(path)
+        return path
+
+    # this odd approach deals with concurrent directory creation
+    if not op.exists(op.abspath(path)):
+        fmlogger.debug("Creating directory %s", path)
+        try:
+            os.makedirs(path)
+        except OSError:
+            fmlogger.debug("Problem creating directory %s", path)
+            if not op.exists(path):
+                raise OSError('Could not create directory %s' % path)
+    return path
+
+
+def emptydirs(path, noexist_ok=False):
+    """
+    Empty an existing directory, without deleting it. Do not
+    raise error if the path does not exist and noexist_ok is True.
+
+    Parameters
+    ----------
+    path : directory that should be empty
+
+    """
+    fmlogger.debug("Removing contents of %s", path)
+
+    if noexist_ok and not op.exists(path):
+        return True
+
+    if op.isfile(path):
+        raise OSError('path "%s" should be a directory' % path)
+
+    try:
+        shutil.rmtree(path)
+    except OSError as ex:
+        elcont = os.listdir(path)
+        if ex.errno == errno.ENOTEMPTY and not elcont:
+            fmlogger.warning(
+                'An exception was raised trying to remove old %s, but the path '
+                'seems empty. Is it an NFS mount? 
Passing the exception.', path) + elif ex.errno == errno.ENOTEMPTY and elcont: + fmlogger.debug('Folder %s contents (%d items).', path, len(elcont)) + raise ex + else: + raise ex + + makedirs(path) + + def which(cmd, env=None, pathext=None): """ Return the path to an executable which would be run if the given @@ -701,8 +763,8 @@ def which(cmd, env=None, pathext=None): for ext in pathext: extcmd = cmd + ext for directory in path.split(os.pathsep): - filename = os.path.join(directory, extcmd) - if os.path.exists(filename): + filename = op.join(directory, extcmd) + if op.exists(filename): return filename return None @@ -758,3 +820,39 @@ def canonicalize_env(env): val = val.encode('utf-8') out_env[key] = val return out_env + + +def relpath(path, start=None): + """Return a relative version of a path""" + + try: + return op.relpath(path, start) + except AttributeError: + pass + + if start is None: + start = os.curdir + if not path: + raise ValueError("no path specified") + start_list = op.abspath(start).split(op.sep) + path_list = op.abspath(path).split(op.sep) + if start_list[0].lower() != path_list[0].lower(): + unc_path, rest = op.splitunc(path) + unc_start, rest = op.splitunc(start) + if bool(unc_path) ^ bool(unc_start): + raise ValueError(("Cannot mix UNC and non-UNC paths " + "(%s and %s)") % (path, start)) + else: + raise ValueError("path is on drive %s, start on drive %s" + % (path_list[0], start_list[0])) + # Work out how much of the filepath is shared by start and path. + for i in range(min(len(start_list), len(path_list))): + if start_list[i].lower() != path_list[i].lower(): + break + else: + i += 1 + + rel_list = [op.pardir] * (len(start_list) - i) + path_list[i:] + if not rel_list: + return os.curdir + return op.join(*rel_list) diff --git a/nipype/utils/logger.py b/nipype/utils/logger.py index 4604cc4145..2bdc54c791 100644 --- a/nipype/utils/logger.py +++ b/nipype/utils/logger.py @@ -97,42 +97,7 @@ def logdebug_dict_differences(self, dold, dnew, prefix=""): typical use -- log difference for hashed_inputs """ - # First check inputs, since they usually are lists of tuples - # and dicts are required. - if isinstance(dnew, list): - dnew = dict(dnew) - if isinstance(dold, list): - dold = dict(dold) - - # Compare against hashed_inputs - # Keys: should rarely differ - new_keys = set(dnew.keys()) - old_keys = set(dold.keys()) - if len(new_keys - old_keys): - self._logger.debug("%s not previously seen: %s" - % (prefix, new_keys - old_keys)) - if len(old_keys - new_keys): - self._logger.debug("%s not presently seen: %s" - % (prefix, old_keys - new_keys)) - - # Values in common keys would differ quite often, - # so we need to join the messages together - msgs = [] - for k in new_keys.intersection(old_keys): - same = False - try: - new, old = dnew[k], dold[k] - same = new == old - if not same: - # Since JSON does not discriminate between lists and - # tuples, we might need to cast them into the same type - # as the last resort. 
And lets try to be more generic - same = old.__class__(new) == old - except Exception as e: - same = False - if not same: - msgs += ["%s: %r != %r" - % (k, dnew[k], dold[k])] - if len(msgs): - self._logger.debug("%s values differ in fields: %s" % (prefix, - ", ".join(msgs))) + from .misc import dict_diff + self._logger.warning("logdebug_dict_differences has been deprecated, please use " + "nipype.utils.misc.dict_diff.") + self._logger.debug(dict_diff(dold, dnew)) diff --git a/nipype/utils/misc.py b/nipype/utils/misc.py index 81b29366a1..0d5942940a 100644 --- a/nipype/utils/misc.py +++ b/nipype/utils/misc.py @@ -4,19 +4,29 @@ """Miscellaneous utility functions """ from __future__ import print_function, unicode_literals, division, absolute_import -from future import standard_library -standard_library.install_aliases() from builtins import next, str -from future.utils import raise_from import sys import re from collections import Iterator -import inspect from distutils.version import LooseVersion -from textwrap import dedent + import numpy as np +from future.utils import raise_from +from future import standard_library +try: + from textwrap import indent as textwrap_indent +except ImportError: + def textwrap_indent(text, prefix): + """ A textwrap.indent replacement for Python < 3.3 """ + if not prefix: + return text + splittext = text.splitlines(True) + return prefix + prefix.join(splittext) + +standard_library.install_aliases() + def human_order_sorted(l): """Sorts string in human order (i.e. 'stat10' will go after 'stat2')""" @@ -197,11 +207,11 @@ def unflatten(in_list, prev_structure): if not isinstance(prev_structure, list): return next(in_list) - else: - out = [] - for item in prev_structure: - out.append(unflatten(in_list, item)) - return out + + out = [] + for item in prev_structure: + out.append(unflatten(in_list, item)) + return out def normalize_mc_params(params, source): @@ -229,3 +239,57 @@ def normalize_mc_params(params, source): params[-1:2:-1] = aff2euler(matrix) return params + + +def dict_diff(dold, dnew, indent=0): + """Helper to log what actually changed from old to new values of + dictionaries. + + typical use -- log difference for hashed_inputs + """ + # First check inputs, since they usually are lists of tuples + # and dicts are required. + if isinstance(dnew, list): + dnew = dict(dnew) + if isinstance(dold, list): + dold = dict(dold) + + # Compare against hashed_inputs + # Keys: should rarely differ + new_keys = set(dnew.keys()) + old_keys = set(dold.keys()) + + diff = [] + if new_keys - old_keys: + diff += [" * keys not previously seen: %s" % (new_keys - old_keys)] + + if old_keys - new_keys: + diff += [" * keys not presently seen: %s" % (old_keys - new_keys)] + + # Add topical message + if diff: + diff.insert(0, "Dictionaries had differing keys:") + + diffkeys = len(diff) + + # Values in common keys would differ quite often, + # so we need to join the messages together + for k in new_keys.intersection(old_keys): + same = False + try: + new, old = dnew[k], dold[k] + same = new == old + if not same: + # Since JSON does not discriminate between lists and + # tuples, we might need to cast them into the same type + # as the last resort. 
And lets try to be more generic + same = old.__class__(new) == old + except Exception: + same = False + if not same: + diff += [" * %s: %r != %r" % (k, dnew[k], dold[k])] + + if len(diff) > diffkeys: + diff.insert(diffkeys, "Some dictionary entries had differing values:") + + return textwrap_indent('\n'.join(diff), ' ' * indent) diff --git a/nipype/utils/tests/test_config.py b/nipype/utils/tests/test_config.py index 869b733c2e..7684bdd55e 100644 --- a/nipype/utils/tests/test_config.py +++ b/nipype/utils/tests/test_config.py @@ -193,3 +193,9 @@ def test_display_empty_macosx(monkeypatch): monkeypatch.setattr(sys, 'platform', 'darwin') with pytest.raises(RuntimeError): config.get_display() + +def test_cwd_cached(tmpdir): + """Check that changing dirs does not change nipype's cwd""" + oldcwd = config.cwd + tmpdir.chdir() + assert config.cwd == oldcwd diff --git a/nipype/workflows/fmri/spm/preprocess.py b/nipype/workflows/fmri/spm/preprocess.py index 384284434d..1a8b8cddee 100644 --- a/nipype/workflows/fmri/spm/preprocess.py +++ b/nipype/workflows/fmri/spm/preprocess.py @@ -8,7 +8,6 @@ from ....interfaces import spm as spm from ....interfaces import utility as niu from ....pipeline import engine as pe -from ....interfaces.matlab import no_matlab from ...smri.freesurfer.utils import create_getmask_flow from .... import logging @@ -141,7 +140,8 @@ def create_vbm_preproc(name='vbmpreproc'): >>> preproc = create_vbm_preproc() >>> preproc.inputs.inputspec.fwhm = 8 - >>> preproc.inputs.inputspec.structural_files = [os.path.abspath('s1.nii'), os.path.abspath('s3.nii')] + >>> preproc.inputs.inputspec.structural_files = [ + ... os.path.abspath('s1.nii'), os.path.abspath('s3.nii')] >>> preproc.inputs.inputspec.template_prefix = 'Template' >>> preproc.run() # doctest: +SKIP @@ -185,7 +185,9 @@ def getclass1images(class_images): class1images.extend(session[0]) return class1images - workflow.connect(dartel_template, ('segment.native_class_images', getclass1images), norm2mni, 'apply_to_files') + workflow.connect( + dartel_template, ('segment.native_class_images', getclass1images), + norm2mni, 'apply_to_files') workflow.connect(inputnode, 'fwhm', norm2mni, 'fwhm') def compute_icv(class_images): @@ -217,10 +219,11 @@ def compute_icv(class_images): "icv" ]), name="outputspec") - workflow.connect([(dartel_template, outputnode, [('outputspec.template_file', 'template_file')]), - (norm2mni, outputnode, [("normalized_files", "normalized_files")]), - (calc_icv, outputnode, [("icv", "icv")]), - ]) + workflow.connect([ + (dartel_template, outputnode, [('outputspec.template_file', 'template_file')]), + (norm2mni, outputnode, [("normalized_files", "normalized_files")]), + (calc_icv, outputnode, [("icv", "icv")]), + ]) return workflow @@ -233,7 +236,8 @@ def create_DARTEL_template(name='dartel_template'): ------- >>> preproc = create_DARTEL_template() - >>> preproc.inputs.inputspec.structural_files = [os.path.abspath('s1.nii'), os.path.abspath('s3.nii')] + >>> preproc.inputs.inputspec.structural_files = [ + ... 
os.path.abspath('s1.nii'), os.path.abspath('s3.nii')] >>> preproc.inputs.inputspec.template_prefix = 'Template' >>> preproc.run() # doctest: +SKIP @@ -259,24 +263,34 @@ def create_DARTEL_template(name='dartel_template'): name='segment') workflow.connect(inputnode, 'structural_files', segment, 'channel_files') - version = spm.Info.version() - if version: - spm_path = version['path'] - if version['name'] == 'SPM8': - tissue1 = ((os.path.join(spm_path, 'toolbox/Seg/TPM.nii'), 1), 2, (True, True), (False, False)) - tissue2 = ((os.path.join(spm_path, 'toolbox/Seg/TPM.nii'), 2), 2, (True, True), (False, False)) - tissue3 = ((os.path.join(spm_path, 'toolbox/Seg/TPM.nii'), 3), 2, (True, False), (False, False)) - tissue4 = ((os.path.join(spm_path, 'toolbox/Seg/TPM.nii'), 4), 3, (False, False), (False, False)) - tissue5 = ((os.path.join(spm_path, 'toolbox/Seg/TPM.nii'), 5), 4, (False, False), (False, False)) - tissue6 = ((os.path.join(spm_path, 'toolbox/Seg/TPM.nii'), 6), 2, (False, False), (False, False)) - elif version['name'] == 'SPM12': - spm_path = version['path'] + spm_info = spm.Info.getinfo() + if spm_info: + spm_path = spm_info['path'] + if spm_info['name'] == 'SPM8': + tissue1 = ((os.path.join(spm_path, 'toolbox/Seg/TPM.nii'), 1), + 2, (True, True), (False, False)) + tissue2 = ((os.path.join(spm_path, 'toolbox/Seg/TPM.nii'), 2), + 2, (True, True), (False, False)) + tissue3 = ((os.path.join(spm_path, 'toolbox/Seg/TPM.nii'), 3), + 2, (True, False), (False, False)) + tissue4 = ((os.path.join(spm_path, 'toolbox/Seg/TPM.nii'), 4), + 3, (False, False), (False, False)) + tissue5 = ((os.path.join(spm_path, 'toolbox/Seg/TPM.nii'), 5), + 4, (False, False), (False, False)) + tissue6 = ((os.path.join(spm_path, 'toolbox/Seg/TPM.nii'), 6), + 2, (False, False), (False, False)) + elif spm_info['name'] == 'SPM12': + spm_path = spm_info['path'] tissue1 = ((os.path.join(spm_path, 'tpm/TPM.nii'), 1), 1, (True, True), (False, False)) tissue2 = ((os.path.join(spm_path, 'tpm/TPM.nii'), 2), 1, (True, True), (False, False)) - tissue3 = ((os.path.join(spm_path, 'tpm/TPM.nii'), 3), 2, (True, False), (False, False)) - tissue4 = ((os.path.join(spm_path, 'tpm/TPM.nii'), 4), 3, (False, False), (False, False)) - tissue5 = ((os.path.join(spm_path, 'tpm/TPM.nii'), 5), 4, (False, False), (False, False)) - tissue6 = ((os.path.join(spm_path, 'tpm/TPM.nii'), 6), 2, (False, False), (False, False)) + tissue3 = ((os.path.join(spm_path, 'tpm/TPM.nii'), 3), + 2, (True, False), (False, False)) + tissue4 = ((os.path.join(spm_path, 'tpm/TPM.nii'), 4), + 3, (False, False), (False, False)) + tissue5 = ((os.path.join(spm_path, 'tpm/TPM.nii'), 5), + 4, (False, False), (False, False)) + tissue6 = ((os.path.join(spm_path, 'tpm/TPM.nii'), 6), + 2, (False, False), (False, False)) else: logger.critical('Unsupported version of SPM')
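
The SPM8 and SPM12 branches above differ only in where TPM.nii ships (toolbox/Seg/ versus tpm/) and in the number of Gaussians used for the first two tissue classes; the write flags follow the same pattern in both. As a rough illustrative sketch only (the helper name ``dartel_tissues`` and this factoring are hypothetical, not part of this changeset; the tuple fields are assumed to mean ((tpm_file, 1-based frame), n_gaussians, (write_native, write_dartel), (write_unmodulated, write_modulated))), the same tissue list could be built in one place:

    import os


    def dartel_tissues(spm_path, spm_name):
        """Sketch: rebuild the tissue tuples used above for SPM8/SPM12."""
        # SPM8 ships its tissue probability maps under toolbox/Seg,
        # SPM12 under tpm/ (mirrors the two branches above).
        tpm_rel = 'toolbox/Seg/TPM.nii' if spm_name == 'SPM8' else 'tpm/TPM.nii'
        tpm = os.path.join(spm_path, tpm_rel)
        # Gaussians per tissue class, per release (same values as above).
        ngaus = [2, 2, 2, 3, 4, 2] if spm_name == 'SPM8' else [1, 1, 2, 3, 4, 2]
        # GM/WM also write DARTEL imports, CSF native only, the rest not written.
        native = [(True, True), (True, True), (True, False),
                  (False, False), (False, False), (False, False)]
        return [((tpm, i + 1), ngaus[i], native[i], (False, False))
                for i in range(6)]

Under the same ``if spm_info:`` guard, the returned list would presumably be what ends up in the segmentation node's ``tissues`` input.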
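
Similarly, for the new ``dict_diff`` helper introduced in ``nipype/utils/misc.py`` above, a minimal usage sketch (the dictionaries are invented for illustration; the report wording is paraphrased from the implementation):

    from nipype.utils.misc import dict_diff

    old = {'subject_id': 'sub-01', 'fwhm': 6, 'mask': 'brain.nii'}
    new = {'subject_id': 'sub-01', 'fwhm': 8, 'tr': 2.0}

    # dict_diff(dold, dnew, indent) reports keys present in only one of the
    # two dictionaries and common keys whose values differ, with every line
    # prefixed by ``indent`` spaces -- roughly:
    #   Dictionaries had differing keys:
    #    * keys not previously seen: {'tr'}
    #    * keys not presently seen: {'mask'}
    #   Some dictionary entries had differing values:
    #    * fwhm: 8 != 6
    print(dict_diff(old, new, indent=2))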