# Additions to nipype/interfaces/io.py: an SSH-backed variant of DataGrabber.
#
# Changelog entry (CHANGES, "Next Release"):
#   * ENH: New data grabbing interface that works over SSH connections,
#     SSHDataGrabber
#
# The classes below expect these names to be in scope already in
# nipype/interfaces/io.py: traits, isdefined, DynamicTraitedSpec, DataGrabber,
# DataGrabberInputSpec, human_order_sorted and list_to_filename.

import fnmatch
import os
import posixpath
import re
import subprocess
from warnings import warn

try:
    import paramiko
except ImportError:
    # Optional dependency: only SSHDataGrabber needs it, and that class
    # reports the missing library itself.
    paramiko = None


class SSHDataGrabberInputSpec(DataGrabberInputSpec):
    """Inputs of SSHDataGrabber: the DataGrabber inputs plus SSH settings."""
    hostname = traits.Str(mandatory=True,
                          desc='Server hostname.')
    username = traits.Str(mandatory=False,
                          desc='Server username.')
    password = traits.Password(mandatory=False,
                               desc='Server password.')
    download_files = traits.Bool(
        True, usedefault=True,
        desc='If false it will return the file names without downloading them')
    base_directory = traits.Str(
        mandatory=True,
        desc='Path to the base directory consisting of subject data.')
    template_expression = traits.Enum(
        ['fnmatch', 'regexp'], usedefault=True,
        desc='Use either fnmatch or regexp to express templates')
    ssh_log_to_file = traits.Str(
        '', usedefault=True,
        desc='If set SSH commands will be logged to the given file')


class SSHDataGrabber(DataGrabber):
    """ Extension of DataGrabber module that downloads the file list and
        optionally the files from a SSH server.

        Connection settings (host name alias, user, port, proxy command) are
        looked up in ``~/.ssh/config`` when that file exists. If the
        ``username`` / ``password`` inputs are set they are used for the
        connection; otherwise authentication relies on the configured user
        and on keys, e.g. through a running SSH agent.

        .. attention::

           Doesn't support directories currently

        Examples
        --------

        >>> from nipype.interfaces.io import SSHDataGrabber
        >>> dg = SSHDataGrabber()
        >>> dg.inputs.hostname = 'test.rebex.net'
        >>> dg.inputs.username = 'demo'
        >>> dg.inputs.password = 'password'
        >>> dg.inputs.base_directory = 'pub/example'

        Pick all files from the base directory

        >>> dg.inputs.template = '*'

        Pick all files starting with "pop" and a digit from the base
        directory

        >>> dg.inputs.template_expression = 'regexp'
        >>> dg.inputs.template = 'pop[0-9].*'

        Same thing but with dynamically created fields

        >>> dg = SSHDataGrabber(infields=['arg1','arg2'])
        >>> dg.inputs.hostname = 'test.rebex.net'
        >>> dg.inputs.username = 'demo'
        >>> dg.inputs.password = 'password'
        >>> dg.inputs.base_directory = 'pub'
        >>> dg.inputs.template = '%s/%s.txt'
        >>> dg.inputs.arg1 = 'example'
        >>> dg.inputs.arg2 = 'foo'

        however this latter form can be used with iterables and iterfield in a
        pipeline.

        Dynamically created, user-defined input and output fields

        >>> dg = SSHDataGrabber(infields=['sid'], outfields=['func','struct','ref'])
        >>> dg.inputs.hostname = 'myhost.com'
        >>> dg.inputs.base_directory = '/main_folder/my_remote_dir'
        >>> dg.inputs.template_args['func'] = [['sid',['f3','f5']]]
        >>> dg.inputs.template_args['struct'] = [['sid',['struct']]]
        >>> dg.inputs.template_args['ref'] = [['sid','ref']]
        >>> dg.inputs.sid = 's1'

        Change the template only for output field struct. The rest use the
        general template

        >>> dg.inputs.field_template = dict(struct='%s/struct.nii')
        >>> dg.inputs.template_args['struct'] = [['sid']]

    """
    input_spec = SSHDataGrabberInputSpec
    output_spec = DynamicTraitedSpec
    _always_run = False

    def __init__(self, infields=None, outfields=None, **kwargs):
        """
        Parameters
        ----------
        infields : list of str
            Indicates the input fields to be dynamically created

        outfields: list of str
            Indicates output fields to be dynamically created

        Raises
        ------
        ValueError
            If only one of ``username`` / ``password`` is passed.

        See class examples for usage

        """
        if paramiko is None:
            warn(
                "The library paramiko needs to be installed"
                " for this module to run."
            )
        if not outfields:
            outfields = ['outfiles']
        kwargs = kwargs.copy()
        kwargs['infields'] = infields
        kwargs['outfields'] = outfields
        super(SSHDataGrabber, self).__init__(**kwargs)

        # Unset traits are ``Undefined`` rather than ``None``, so the check
        # has to go through isdefined() to ever trigger.
        if isdefined(self.inputs.username) != isdefined(self.inputs.password):
            raise ValueError(
                "either both username and password "
                "are provided or none of them"
            )

        # Anchor a regexp template given at construction time so that it has
        # to match the whole file name. Templates set later are anchored when
        # they are matched (see _fetch_matches).
        template = self.inputs.template
        if (self.inputs.template_expression == 'regexp' and
                isdefined(template) and template and
                not template.endswith('$')):
            self.inputs.template = template + '$'

    def _list_outputs(self):
        """Resolve every output field against the remote base directory.

        Returns a dict mapping each key of ``template_args`` to a file name,
        a list of file names, or ``None`` when nothing matched. Matching
        files are downloaded into the current working directory when
        ``download_files`` is true.

        Raises ImportError if paramiko is missing, ValueError for an unset
        infield or inconsistent argument lists, and IOError when nothing
        matches and ``raise_on_empty`` is set.
        """
        if paramiko is None:
            raise ImportError(
                "The library paramiko needs to be installed"
                " for this module to run."
            )

        if self.inputs.ssh_log_to_file:
            paramiko.util.log_to_file(self.inputs.ssh_log_to_file)

        # infields are mandatory, however the 'mandatory' flag cannot be set
        # dynamically, hence the manual check
        for key in self._infields or []:
            if not isdefined(getattr(self.inputs, key)):
                msg = "%s requires a value for input '%s' because it was listed in 'infields'" % \
                    (self.__class__.__name__, key)
                raise ValueError(msg)

        fields = list(self.inputs.template_args.items())
        if not fields:
            # Nothing to look up: do not open a connection at all.
            return {}

        # One connection serves every lookup and is always closed again.
        client = self._get_ssh_client()
        try:
            sftp = client.open_sftp()
            sftp.chdir(self.inputs.base_directory)
            outputs = {}
            for key, args in fields:
                outputs[key] = self._grab_field(sftp, key, args)
            return outputs
        finally:
            client.close()

    def _grab_field(self, sftp, key, args):
        """Return the value of output field ``key`` for its template args."""
        template = self.inputs.template
        if hasattr(self.inputs, 'field_template') and \
                isdefined(self.inputs.field_template) and \
                key in self.inputs.field_template:
            template = self.inputs.field_template[key]

        if not args:
            # Without arguments the template is matched as is against the
            # entries of the base directory.
            found = self._fetch_matches(sftp, key, '', template, template)
            return list_to_filename(found) if found else None

        results = []
        for arglist in args:
            # Strings naming an input are replaced by that input's value.
            resolved = [
                getattr(self.inputs, arg)
                if isinstance(arg, str) and hasattr(self.inputs, arg)
                else arg
                for arg in arglist
            ]
            maxlen = 1
            for arg in resolved:
                if isinstance(arg, list):
                    if (maxlen > 1) and (len(arg) != maxlen):
                        raise ValueError('incompatible number of arguments for %s' % key)
                    if len(arg) > maxlen:
                        maxlen = len(arg)
            for i in range(maxlen):
                argtuple = tuple(
                    arg[i] if isinstance(arg, list) else arg
                    for arg in resolved
                )
                filledtemplate = template
                if argtuple:
                    try:
                        filledtemplate = template % argtuple
                    except TypeError as e:
                        raise TypeError(
                            "%s: Template %s failed to convert with args %s"
                            % (e, template, str(argtuple)))
                found = self._fetch_matches(
                    sftp, key,
                    posixpath.dirname(filledtemplate),
                    posixpath.basename(filledtemplate),
                    filledtemplate)
                results.append(list_to_filename(found) if found else None)

        # A single failed lookup invalidates the whole field.
        if not results or any(val is None for val in results):
            return None
        if len(results) == 1:
            return results[0]
        return results

    def _fetch_matches(self, sftp, key, remote_dir, pattern, label):
        """List ``remote_dir``, keep names matching ``pattern``, download them.

        ``remote_dir`` is relative to the base directory ('' for the base
        directory itself); ``label`` is the template shown in the
        "returned no files" message. Returns the (optionally sorted) list of
        matching names, empty when nothing matched and ``raise_on_empty`` is
        false.
        """
        names = sftp.listdir(remote_dir or '.')
        expression = self.inputs.template_expression
        if expression == 'fnmatch':
            matches = fnmatch.filter(names, pattern)
        elif expression == 'regexp':
            # Require a match of the whole name, also for templates that
            # were set after construction or come from field_template.
            if not pattern.endswith('$'):
                pattern += '$'
            regexp = re.compile(pattern)
            matches = [name for name in names if regexp.match(name)]
        else:
            raise ValueError('template_expression value invalid')

        if not matches:
            msg = 'Output key: %s Template: %s returned no files' % (key, label)
            if self.inputs.raise_on_empty:
                raise IOError(msg)
            warn(msg)
            return []

        if self.inputs.sort_filelist:
            matches = human_order_sorted(matches)
        if self.inputs.download_files:
            for name in matches:
                # Remote paths are POSIX paths whatever the local platform.
                sftp.get(posixpath.join(remote_dir, name), name)
        return matches

    def _get_ssh_client(self):
        """Open and return a connected ``paramiko.SSHClient``.

        Host settings come from ``~/.ssh/config`` when the file exists. The
        ``username`` / ``password`` inputs take precedence over the
        configured user. The caller is responsible for closing the client.
        """
        config = paramiko.SSHConfig()
        config_path = os.path.expanduser('~/.ssh/config')
        if os.path.isfile(config_path):
            with open(config_path) as config_file:
                config.parse(config_file)
        host = config.lookup(self.inputs.hostname)

        proxy = None
        if 'proxycommand' in host:
            # Let the shell expand variables in the user's own ProxyCommand.
            shell = os.environ.get('SHELL', '/bin/sh')
            command = subprocess.check_output(
                [shell, '-c', 'echo %s' % host['proxycommand']]
            ).strip()
            if not isinstance(command, str):
                command = command.decode()
            proxy = paramiko.ProxyCommand(command)

        connect_args = {'sock': proxy}
        if isdefined(self.inputs.username):
            connect_args['username'] = self.inputs.username
        elif 'user' in host:
            connect_args['username'] = host['user']
        if isdefined(self.inputs.password):
            connect_args['password'] = self.inputs.password
        if 'port' in host:
            connect_args['port'] = int(host['port'])

        client = paramiko.SSHClient()
        client.load_system_host_keys()
        # NOTE(review): AutoAddPolicy silently trusts unknown host keys, which
        # leaves first connections open to man-in-the-middle attacks.
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        client.connect(host.get('hostname', self.inputs.hostname),
                       **connect_args)
        return client