Skip to content

Commit 6d7af13

Browse files
committed
ENH: Added a datagrabber capable of grabbing files from an SSH access
1 parent 4b50091 commit 6d7af13

File tree

1 file changed

+273
-1
lines changed

1 file changed

+273
-1
lines changed

nipype/interfaces/io.py

Lines changed: 273 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
1919
"""
2020
import glob
21+
import fnmatch
2122
import string
2223
import os
2324
import os.path as op
2425
import shutil
26+
import subprocess
2527
import re
2628
import tempfile
2729
from warnings import warn
@@ -34,6 +36,11 @@
3436
except:
3537
pass
3638

39+
try:
40+
import paramiko
41+
except:
42+
pass
43+
3744
from nipype.interfaces.base import (TraitedSpec, traits, File, Directory,
3845
BaseInterface, InputMultiPath, isdefined,
3946
OutputMultiPath, DynamicTraitedSpec,
@@ -750,7 +757,7 @@ class DataFinder(IOBase):
750757
'013-ep2d_fid_T1_pre']
751758
>>> print result.outputs.basename # doctest: +SKIP
752759
['acquisition',
753-
'acquisition',
760+
'acquisition'
754761
'acquisition',
755762
'acquisition']
756763
@@ -1539,3 +1546,268 @@ def _list_outputs(self):
15391546
conn.commit()
15401547
c.close()
15411548
return None
1549+
1550+
class SSHDataGrabberInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec):
1551+
hostname = traits.Str(mandatory=True,
1552+
desc='Server hostname.')
1553+
download_files = traits.Bool(True, usedefault=True,
1554+
desc='If false it will return the file names without downloading them')
1555+
base_directory = traits.Str(mandatory=True,
1556+
desc='Path to the base directory consisting of subject data.')
1557+
raise_on_empty = traits.Bool(True, usedefault=True,
1558+
desc='Generate exception if list is empty for a given field')
1559+
sort_filelist = traits.Bool(mandatory=True,
1560+
desc='Sort the filelist that matches the template')
1561+
template = traits.Str(mandatory=True,
1562+
desc='Layout used to get files. relative to base directory if defined')
1563+
template_args = traits.Dict(key_trait=traits.Str,
1564+
value_trait=traits.List(traits.List),
1565+
desc='Information to plug into template')
1566+
template_expression = traits.Enum(['fnmatch', 'regexp'], usedefault=True,
1567+
desc='Use either fnmatch or regexp to express templates')
1568+
1569+
ssh_log_to_file = traits.Str('', usedefault=True,
1570+
desc='If set SSH commands will be logged to the given file')
1571+
1572+
1573+
class SSHDataGrabber(IOBase):
1574+
""" Datagrabber module that downloads the file list and optionally
1575+
the files from a SSH server. The SSH operation must not need
1576+
user and password so an SSH agent must be active in where this
1577+
module is being run.
1578+
1579+
1580+
.. attention::
1581+
1582+
Doesn't support directories currently
1583+
1584+
Examples
1585+
--------
1586+
1587+
>>> from nipype.interfaces.io import SSHDataGrabber
1588+
>>> dg = SSHDataGrabber()
1589+
>>> dg.inputs.hostname = 'myhost.com'
1590+
>>> dg.inputs.base_directory = '/main_folder/my_remote_dir'
1591+
1592+
Pick all files from the base directory
1593+
1594+
>>> dg.inputs.template = '*'
1595+
1596+
Pick all files starting with "s" and a number from current directory
1597+
1598+
>>> dg.inputs.template_expression = 'regexp'
1599+
>>> dg.inputs.template = 's[0-9].*'
1600+
1601+
Same thing but with dynamically created fields
1602+
1603+
>>> dg = SSHDataGrabber(infields=['arg1','arg2'])
1604+
>>> dg.inputs.hostname = 'myhost.com'
1605+
>>> dg.inputs.base_directory = '~/my_remote_dir'
1606+
>>> dg.inputs.template = '%s/%s.nii'
1607+
>>> dg.inputs.arg1 = 'foo'
1608+
>>> dg.inputs.arg2 = 'foo'
1609+
1610+
however this latter form can be used with iterables and iterfield in a
1611+
pipeline.
1612+
1613+
Dynamically created, user-defined input and output fields
1614+
1615+
>>> dg = SSHDataGrabber(infields=['sid'], outfields=['func','struct','ref'])
1616+
>>> dg.inputs.hostname = 'myhost.com'
1617+
>>> dg.inputs.base_directory = '/main_folder/my_remote_dir'
1618+
>>> dg.inputs.template_args['func'] = [['sid',['f3','f5']]]
1619+
>>> dg.inputs.template_args['struct'] = [['sid',['struct']]]
1620+
>>> dg.inputs.template_args['ref'] = [['sid','ref']]
1621+
>>> dg.inputs.sid = 's1'
1622+
1623+
Change the template only for output field struct. The rest use the
1624+
general template
1625+
1626+
>>> dg.inputs.field_template = dict(struct='%s/struct.nii')
1627+
>>> dg.inputs.template_args['struct'] = [['sid']]
1628+
1629+
"""
1630+
input_spec = SSHDataGrabberInputSpec
1631+
output_spec = DynamicTraitedSpec
1632+
_always_run = True
1633+
1634+
def __init__(self, infields=None, outfields=None, **kwargs):
1635+
"""
1636+
Parameters
1637+
----------
1638+
infields : list of str
1639+
Indicates the input fields to be dynamically created
1640+
1641+
outfields: list of str
1642+
Indicates output fields to be dynamically created
1643+
1644+
See class examples for usage
1645+
1646+
"""
1647+
if not outfields:
1648+
outfields = ['outfiles']
1649+
super(SSHDataGrabber, self).__init__(**kwargs)
1650+
undefined_traits = {}
1651+
# used for mandatory inputs check
1652+
self._infields = infields
1653+
self._outfields = outfields
1654+
if infields:
1655+
for key in infields:
1656+
self.inputs.add_trait(key, traits.Any)
1657+
undefined_traits[key] = Undefined
1658+
# add ability to insert field specific templates
1659+
self.inputs.add_trait('field_template',
1660+
traits.Dict(traits.Enum(outfields),
1661+
desc="arguments that fit into template"))
1662+
undefined_traits['field_template'] = Undefined
1663+
if not isdefined(self.inputs.template_args):
1664+
self.inputs.template_args = {}
1665+
for key in outfields:
1666+
if not key in self.inputs.template_args:
1667+
if infields:
1668+
self.inputs.template_args[key] = [infields]
1669+
else:
1670+
self.inputs.template_args[key] = []
1671+
1672+
self.inputs.trait_set(trait_change_notify=False, **undefined_traits)
1673+
1674+
if (
1675+
self.inputs.template_expression == 'regexp' and
1676+
self.inputs.template[-1] != '$'
1677+
):
1678+
self.inputs.template += '$'
1679+
1680+
1681+
1682+
def _add_output_traits(self, base):
1683+
"""
1684+
1685+
Using traits.Any instead out OutputMultiPath till add_trait bug
1686+
is fixed.
1687+
"""
1688+
return add_traits(base, self.inputs.template_args.keys())
1689+
1690+
def _list_outputs(self):
1691+
if len(self.inputs.ssh_log_to_file) > 0:
1692+
paramiko.util.log_to_file(self.inputs.ssh_log_to_file)
1693+
# infields are mandatory, however I could not figure out how to set 'mandatory' flag dynamically
1694+
# hence manual check
1695+
if self._infields:
1696+
for key in self._infields:
1697+
value = getattr(self.inputs, key)
1698+
if not isdefined(value):
1699+
msg = "%s requires a value for input '%s' because it was listed in 'infields'" % \
1700+
(self.__class__.__name__, key)
1701+
raise ValueError(msg)
1702+
1703+
outputs = {}
1704+
for key, args in self.inputs.template_args.items():
1705+
outputs[key] = []
1706+
template = self.inputs.template
1707+
if hasattr(self.inputs, 'field_template') and \
1708+
isdefined(self.inputs.field_template) and \
1709+
key in self.inputs.field_template:
1710+
template = self.inputs.field_template[key]
1711+
#template = os.path.join(
1712+
# os.path.abspath(self.inputs.base_directory), template)
1713+
if not args:
1714+
client = self._get_ssh_client()
1715+
sftp = client.open_sftp()
1716+
sftp.chdir(self.inputs.base_directory)
1717+
filelist = sftp.listdir()
1718+
if self.inputs.template_expression == 'fnmatch':
1719+
filelist = fnmatch.filter(filelist, template)
1720+
elif self.inputs.template_expression == 'regexp':
1721+
regexp = re.compile(template)
1722+
filelist = filter(regexp.match, filelist)
1723+
else:
1724+
raise ValueError('template_expression value invalid')
1725+
if len(filelist) == 0:
1726+
msg = 'Output key: %s Template: %s returned no files' % (
1727+
key, template)
1728+
if self.inputs.raise_on_empty:
1729+
raise IOError(msg)
1730+
else:
1731+
warn(msg)
1732+
else:
1733+
if self.inputs.sort_filelist:
1734+
filelist = human_order_sorted(filelist)
1735+
outputs[key] = list_to_filename(filelist)
1736+
if self.inputs.download_files:
1737+
for f in filelist:
1738+
sftp.get(f, f)
1739+
for argnum, arglist in enumerate(args):
1740+
maxlen = 1
1741+
for arg in arglist:
1742+
if isinstance(arg, str) and hasattr(self.inputs, arg):
1743+
arg = getattr(self.inputs, arg)
1744+
if isinstance(arg, list):
1745+
if (maxlen > 1) and (len(arg) != maxlen):
1746+
raise ValueError('incompatible number of arguments for %s' % key)
1747+
if len(arg) > maxlen:
1748+
maxlen = len(arg)
1749+
outfiles = []
1750+
for i in range(maxlen):
1751+
argtuple = []
1752+
for arg in arglist:
1753+
if isinstance(arg, str) and hasattr(self.inputs, arg):
1754+
arg = getattr(self.inputs, arg)
1755+
if isinstance(arg, list):
1756+
argtuple.append(arg[i])
1757+
else:
1758+
argtuple.append(arg)
1759+
filledtemplate = template
1760+
if argtuple:
1761+
try:
1762+
filledtemplate = template % tuple(argtuple)
1763+
except TypeError as e:
1764+
raise TypeError(e.message + ": Template %s failed to convert with args %s" % (template, str(tuple(argtuple))))
1765+
client = self._get_ssh_client()
1766+
sftp = client.open_sftp()
1767+
sftp.chdir(self.inputs.base_directory)
1768+
filledtemplate_dir = os.path.dirname(filledtemplate)
1769+
filledtemplate_base = os.path.basename(filledtemplate)
1770+
filelist = sftp.listdir(filledtemplate_dir)
1771+
if self.inputs.template_expression == 'fnmatch':
1772+
outfiles = fnmatch.filter(filelist, filledtemplate_base)
1773+
elif self.inputs.template_expression == 'regexp':
1774+
regexp = re.compile(filledtemplate_base)
1775+
outfiles = filter(regexp.match, filelist)
1776+
else:
1777+
raise ValueError('template_expression value invalid')
1778+
if len(outfiles) == 0:
1779+
msg = 'Output key: %s Template: %s returned no files' % (key, filledtemplate)
1780+
if self.inputs.raise_on_empty:
1781+
raise IOError(msg)
1782+
else:
1783+
warn(msg)
1784+
outputs[key].append(None)
1785+
else:
1786+
if self.inputs.sort_filelist:
1787+
outfiles = human_order_sorted(outfiles)
1788+
outputs[key].append(list_to_filename(outfiles))
1789+
if self.inputs.download_files:
1790+
for f in outfiles:
1791+
sftp.get(os.path.join(filledtemplate_dir, f), f)
1792+
if any([val is None for val in outputs[key]]):
1793+
outputs[key] = []
1794+
if len(outputs[key]) == 0:
1795+
outputs[key] = None
1796+
elif len(outputs[key]) == 1:
1797+
outputs[key] = outputs[key][0]
1798+
return outputs
1799+
1800+
def _get_ssh_client(self):
1801+
config = paramiko.SSHConfig()
1802+
config.parse(open(os.path.expanduser('~/.ssh/config')))
1803+
host = config.lookup(self.inputs.hostname)
1804+
proxy = paramiko.ProxyCommand(
1805+
subprocess.check_output(
1806+
[os.environ['SHELL'], '-c', 'echo %s' % host['proxycommand']]
1807+
).strip()
1808+
)
1809+
client = paramiko.SSHClient()
1810+
client.load_system_host_keys()
1811+
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
1812+
client.connect(host['hostname'], username=host['user'], sock=proxy)
1813+
return client

0 commit comments

Comments
 (0)