diff --git a/pandas/io/common.py b/pandas/io/common.py
index 1fc572dbf1a5e..33958ade2bcd6 100644
--- a/pandas/io/common.py
+++ b/pandas/io/common.py
@@ -1,9 +1,14 @@
-""" Common api utilities """
+"""Common IO api utilities"""

+import sys
 import urlparse
-from pandas.util import py3compat
+import urllib2
+import zipfile
+from contextlib import contextmanager, closing
 from StringIO import StringIO
+
+from pandas.util import py3compat

 _VALID_URLS = set(urlparse.uses_relative + urlparse.uses_netloc +
                   urlparse.uses_params)
 _VALID_URLS.discard('')
@@ -84,3 +89,24 @@ def get_filepath_or_buffer(filepath_or_buffer, encoding=None):
             return filepath_or_buffer, None

     return filepath_or_buffer, None
+
+
+# ----------------------
+# Prevent double closing
+if py3compat.PY3:
+    urlopen = urllib2.urlopen
+else:
+    @contextmanager
+    def urlopen(*args, **kwargs):
+        with closing(urllib2.urlopen(*args, **kwargs)) as f:
+            yield f
+
+# ZipFile is not a context manager for <= 2.6
+# must be tuple index here since 2.6 doesn't use namedtuple for version_info
+if sys.version_info[1] <= 6:
+    @contextmanager
+    def ZipFile(*args, **kwargs):
+        with closing(zipfile.ZipFile(*args, **kwargs)) as zf:
+            yield zf
+else:
+    ZipFile = zipfile.ZipFile
diff --git a/pandas/io/data.py b/pandas/io/data.py
index 278fc2fc6dd4d..2d91bd4cd383c 100644
--- a/pandas/io/data.py
+++ b/pandas/io/data.py
@@ -5,19 +5,29 @@
 """
 import warnings
 import tempfile
-
-import numpy as np
+import itertools
 import datetime as dt
 import urllib
 import time
+
 from collections import defaultdict
-from contextlib import closing
-from urllib2 import urlopen
-from zipfile import ZipFile
+import numpy as np
+
 from pandas.util.py3compat import StringIO, bytes_to_str
 from pandas import Panel, DataFrame, Series, read_csv, concat
+from pandas.core.common import PandasError
 from pandas.io.parsers import TextParser
+from pandas.io.common import urlopen, ZipFile
+from pandas.util.testing import _network_error_classes
+
+
+class SymbolWarning(UserWarning):
+    pass
+
+
+class RemoteDataError(PandasError, IOError):
+    pass


 def DataReader(name, data_source=None, start=None, end=None,
@@ -58,16 +68,16 @@ def DataReader(name, data_source=None, start=None, end=None,
     if data_source == "yahoo":
         return get_data_yahoo(symbols=name, start=start, end=end,
-                              adjust_price=False, chunk=25,
+                              adjust_price=False, chunksize=25,
                               retry_count=retry_count, pause=pause)
     elif data_source == "google":
         return get_data_google(symbols=name, start=start, end=end,
-                               adjust_price=False, chunk=25,
+                               adjust_price=False, chunksize=25,
                                retry_count=retry_count, pause=pause)
     elif data_source == "fred":
-        return get_data_fred(name=name, start=start, end=end)
+        return get_data_fred(name, start, end)
     elif data_source == "famafrench":
-        return get_data_famafrench(name=name)
+        return get_data_famafrench(name)


 def _sanitize_dates(start, end):
@@ -88,6 +98,9 @@ def _in_chunks(seq, size):
     return (seq[pos:pos + size] for pos in xrange(0, len(seq), size))


+_yahoo_codes = {'symbol': 's', 'last': 'l1', 'change_pct': 'p2', 'PE': 'r',
+                'time': 't1', 'short_ratio': 's7'}
+
 def get_quote_yahoo(symbols):
     """
     Get current yahoo quote

     Returns a DataFrame
     """
     if isinstance(symbols, basestring):
         sym_list = symbols
-    elif not isinstance(symbols, Series):
-        symbols = Series(symbols)
-        sym_list = '+'.join(symbols)
     else:
         sym_list = '+'.join(symbols)

     # for codes see: http://www.gummy-stuff.org/Yahoo-data.htm
-    codes = {'symbol': 's', 'last': 'l1', 'change_pct': 'p2', 'PE': 'r',
-             'time': 't1', 'short_ratio': 's7'}
-    request = ''.join(codes.itervalues())  # code request string
-    header = codes.keys()
+    request = ''.join(_yahoo_codes.itervalues())  # code request string
+    header = _yahoo_codes.keys()

     data = defaultdict(list)

     url_str = 'http://finance.yahoo.com/d/quotes.csv?s=%s&f=%s' % (sym_list,
                                                                    request)

-    with closing(urlopen(url_str)) as url:
+    with urlopen(url_str) as url:
         lines = url.readlines()

     for line in lines:
@@ -131,7 +139,6 @@ def get_quote_yahoo(symbols):
             data[header[i]].append(v)

     idx = data.pop('symbol')
-
     return DataFrame(data, index=idx)


@@ -139,8 +146,30 @@ def get_quote_google(symbols):
     raise NotImplementedError("Google Finance doesn't have this functionality")


-def _get_hist_yahoo(sym, start=None, end=None, retry_count=3, pause=0.001,
-                    **kwargs):
+def _retry_read_url(url, retry_count, pause, name):
+    for _ in xrange(retry_count):
+        time.sleep(pause)
+
+        # kludge to close the socket ASAP
+        try:
+            with urlopen(url) as resp:
+                lines = resp.read()
+        except _network_error_classes:
+            pass
+        else:
+            rs = read_csv(StringIO(bytes_to_str(lines)), index_col=0,
+                          parse_dates=True)[::-1]
+            # Yahoo! Finance sometimes does this awesome thing where they
+            # return 2 rows for the most recent business day
+            if len(rs) > 2 and rs.index[-1] == rs.index[-2]:  # pragma: no cover
+                rs = rs[:-1]
+            return rs
+
+    raise IOError("after %d tries, %s did not "
+                  "return a 200 for url %r" % (retry_count, name, url))
+
+
+def _get_hist_yahoo(sym, start, end, retry_count, pause):
     """
     Get historical data for the given name from yahoo.
     Date format is datetime
@@ -148,10 +177,8 @@ def _get_hist_yahoo(sym, start=None, end=None, retry_count=3, pause=0.001,
     Returns a DataFrame.
     """
     start, end = _sanitize_dates(start, end)
-
-    yahoo_URL = 'http://ichart.yahoo.com/table.csv?'
-
-    url = (yahoo_URL + 's=%s' % sym +
+    yahoo_url = 'http://ichart.yahoo.com/table.csv?'
+    url = (yahoo_url + 's=%s' % sym +
            '&a=%s' % (start.month - 1) +
            '&b=%s' % start.day +
            '&c=%s' % start.year +
@@ -160,29 +187,10 @@ def _get_hist_yahoo(sym, start=None, end=None, retry_count=3, pause=0.001,
           '&f=%s' % end.year +
           '&g=d' +
           '&ignore=.csv')
-
-    for _ in xrange(retry_count):
-        with closing(urlopen(url)) as resp:
-            if resp.code == 200:
-                lines = resp.read()
-                rs = read_csv(StringIO(bytes_to_str(lines)), index_col=0,
-                              parse_dates=True)[::-1]
-
-                # Yahoo! Finance sometimes does this awesome thing where they
-                # return 2 rows for the most recent business day
-                if len(rs) > 2 and rs.index[-1] == rs.index[-2]:  # pragma: no cover
-                    rs = rs[:-1]
-
-                return rs
-
-        time.sleep(pause)
-
-    raise IOError("after %d tries, Yahoo did not "
-                  "return a 200 for url %r" % (retry_count, url))
+    return _retry_read_url(url, retry_count, pause, 'Yahoo!')


-def _get_hist_google(sym, start=None, end=None, retry_count=3, pause=0.001,
-                     **kwargs):
+def _get_hist_google(sym, start, end, retry_count, pause):
     """
     Get historical data for the given name from google.
     Date format is datetime
@@ -190,7 +198,6 @@ def _get_hist_google(sym, start=None, end=None, retry_count=3, pause=0.001,
     Returns a DataFrame.
     """
     start, end = _sanitize_dates(start, end)
-    google_URL = 'http://www.google.com/finance/historical?'
     # www.google.com/finance/historical?q=GOOG&startdate=Jun+9%2C+2011&enddate=Jun+8%2C+2013&output=csv
@@ -199,25 +206,16 @@ def _get_hist_google(sym, start=None, end=None, retry_count=3, pause=0.001,
                                                                      '%Y'),
                                          "enddate": end.strftime('%b %d, %Y'),
                                          "output": "csv"})
-    for _ in xrange(retry_count):
-        with closing(urlopen(url)) as resp:
-            if resp.code == 200:
-                rs = read_csv(StringIO(bytes_to_str(resp.read())), index_col=0,
-                              parse_dates=True)[::-1]
-
-                return rs
-
-        time.sleep(pause)
-
-    raise IOError("after %d tries, Google did not "
-                  "return a 200 for url %s" % (retry_count, url))
+    return _retry_read_url(url, retry_count, pause, 'Google')


-def _adjust_prices(hist_data, price_list=['Open', 'High', 'Low', 'Close']):
+def _adjust_prices(hist_data, price_list=None):
     """
     Return modifed DataFrame or Panel with adjusted prices based on
     'Adj Close' price. Adds 'Adj_Ratio' column.
     """
+    if price_list is None:
+        price_list = 'Open', 'High', 'Low', 'Close'
     adj_ratio = hist_data['Adj Close'] / hist_data['Close']

     data = hist_data.copy()
@@ -234,7 +232,7 @@ def _calc_return_index(price_df):
     (typically NaN) is set to 1.
     """
     df = price_df.pct_change().add(1).cumprod()
-    mask = ~df.ix[1].isnull() & df.ix[0].isnull()
+    mask = df.ix[1].notnull() & df.ix[0].isnull()
     df.ix[0][mask] = 1

     # Check for first stock listings after starting date of index in ret_index
@@ -245,8 +243,7 @@ def _calc_return_index(price_df):
         t_idx = df.index.get_loc(tstamp) - 1
         df[sym].ix[t_idx] = 1

-    ret_index = df
-    return ret_index
+    return df


 def get_components_yahoo(idx_sym):
@@ -287,7 +284,7 @@ def get_components_yahoo(idx_sym):
     # break when no new components are found
     while True in mask:
         url_str = url.format(idx_mod, stats, comp_idx)
-        with closing(urlopen(url_str)) as resp:
+        with urlopen(url_str) as resp:
             raw = resp.read()
         lines = raw.decode('utf-8').strip().strip('"').split('"\r\n"')
         lines = [line.strip().split('","') for line in lines]
@@ -303,22 +300,54 @@ def get_components_yahoo(idx_sym):
     return idx_df


-def _dl_mult_symbols(symbols, start, end, chunksize, pause, method, **kwargs):
+def _dl_mult_symbols(symbols, start, end, chunksize, retry_count, pause,
+                     method):
     stocks = {}
     for sym_group in _in_chunks(symbols, chunksize):
         for sym in sym_group:
             try:
-                stocks[sym] = method(sym, start=start, end=end, pause=pause,
-                                     **kwargs)
+                stocks[sym] = method(sym, start, end, retry_count, pause)
             except IOError:
-                warnings.warn('ERROR with symbol: {0}, skipping.'.format(sym))
+                warnings.warn('Failed to read symbol: {0!r}, replacing with '
+                              'NaN.'.format(sym), SymbolWarning)
+                stocks[sym] = np.nan

     return Panel(stocks).swapaxes('items', 'minor')


+_source_functions = {'google': _get_hist_google, 'yahoo': _get_hist_yahoo}
+
+def _get_data_from(symbols, start, end, retry_count, pause, adjust_price,
+                   ret_index, chunksize, source, name):
+    if name is not None:
+        warnings.warn("Arg 'name' is deprecated, please use 'symbols' "
+                      "instead.", FutureWarning)
+        symbols = name
+
+    src_fn = _source_functions[source]
+
+    # If a single symbol, (e.g., 'GOOG')
+    if isinstance(symbols, (basestring, int)):
+        hist_data = src_fn(symbols, start, end, retry_count, pause)
+    # Or multiple symbols, (e.g., ['GOOG', 'AAPL', 'MSFT'])
+    elif isinstance(symbols, DataFrame):
+        hist_data = _dl_mult_symbols(symbols.index, start, end, chunksize,
+                                     retry_count, pause, src_fn)
+    else:
+        hist_data = _dl_mult_symbols(symbols, start, end, chunksize,
+                                     retry_count, pause, src_fn)
+
+    if source.lower() == 'yahoo':
+        if ret_index:
+            hist_data['Ret_Index'] = _calc_return_index(hist_data['Adj Close'])
+        if adjust_price:
+            hist_data = _adjust_prices(hist_data)
+
+    return hist_data
+
+
 def get_data_yahoo(symbols=None, start=None, end=None, retry_count=3,
                    pause=0.001, adjust_price=False, ret_index=False,
-                   chunksize=25, **kwargs):
+                   chunksize=25, name=None):
     """
     Returns DataFrame/Panel of historical stock prices from symbols, over date
     range, start to end. To avoid being penalized by Yahoo! Finance servers,
@@ -352,32 +381,13 @@ def get_data_yahoo(symbols=None, start=None, end=None, retry_count=3,
     -------
     hist_data : DataFrame (str) or Panel (array-like object, DataFrame)
     """
-    if 'name' in kwargs:
-        warnings.warn("Arg 'name' is deprecated, please use 'symbols' "
-                      "instead.", FutureWarning)
-        symbols = kwargs['name']
-
-    # If a single symbol, (e.g., 'GOOG')
-    if isinstance(symbols, (basestring, int)):
-        hist_data = _get_hist_yahoo(symbols, start=start, end=end)
-    # Or multiple symbols, (e.g., ['GOOG', 'AAPL', 'MSFT'])
-    elif isinstance(symbols, DataFrame):
-        hist_data = _dl_mult_symbols(symbols.index, start, end, chunksize,
-                                     pause, _get_hist_yahoo, **kwargs)
-    else:
-        hist_data = _dl_mult_symbols(symbols, start, end, chunksize, pause,
-                                     _get_hist_yahoo, **kwargs)
-
-    if ret_index:
-        hist_data['Ret_Index'] = _calc_return_index(hist_data['Adj Close'])
-    if adjust_price:
-        hist_data = _adjust_prices(hist_data)
-
-    return hist_data
+    return _get_data_from(symbols, start, end, retry_count, pause,
+                          adjust_price, ret_index, chunksize, 'yahoo', name)


 def get_data_google(symbols=None, start=None, end=None, retry_count=3,
-                    pause=0.001, chunksize=25, **kwargs):
+                    pause=0.001, adjust_price=False, ret_index=False,
+                    chunksize=25, name=None):
     """
     Returns DataFrame/Panel of historical stock prices from symbols, over date
     range, start to end. To avoid being penalized by Google Finance servers,
@@ -405,21 +415,8 @@ def get_data_google(symbols=None, start=None, end=None, retry_count=3,
     -------
     hist_data : DataFrame (str) or Panel (array-like object, DataFrame)
     """
-    if 'name' in kwargs:
-        warnings.warn("Arg 'name' is deprecated, please use 'symbols' "
-                      "instead.", FutureWarning)
-        symbols = kwargs['name']
-
-    # If a single symbol, (e.g., 'GOOG')
-    if isinstance(symbols, (basestring, int)):
-        return _get_hist_google(symbols, start=start, end=end)
-    # Or multiple symbols, (e.g., ['GOOG', 'AAPL', 'MSFT'])
-    elif isinstance(symbols, DataFrame):
-        symbs = symbols.index
-    else:  # Guess a Series
-        symbs = symbols
-    return _dl_mult_symbols(symbs, start, end, chunksize, pause,
-                            _get_hist_google, **kwargs)
+    return _get_data_from(symbols, start, end, retry_count, pause,
+                          adjust_price, ret_index, chunksize, 'google', name)


 def get_data_fred(name, start=dt.datetime(2010, 1, 1),
@@ -435,7 +432,7 @@ def get_data_fred(name, start=dt.datetime(2010, 1, 1),
     fred_URL = "http://research.stlouisfed.org/fred2/series/"

     url = fred_URL + '%s' % name + '/downloaddata/%s' % name + '.csv'
-    with closing(urlopen(url)) as resp:
+    with urlopen(url) as resp:
         data = read_csv(resp, index_col=0, parse_dates=True,
                         header=None, skiprows=1, names=["DATE", name],
                         na_values='.')
@@ -448,39 +445,39 @@ def get_data_fred(name, start=dt.datetime(2010, 1, 1),
         raise


-def get_data_famafrench(name, start=None, end=None):
-    start, end = _sanitize_dates(start, end)
-
+def get_data_famafrench(name):
     # path of zip files
-    zipFileURL = "http://mba.tuck.dartmouth.edu/pages/faculty/ken.french/ftp/"
+    zip_file_url = ('http://mba.tuck.dartmouth.edu/pages/faculty/'
+                    'ken.french/ftp/')
+    zip_file_path = '{0}{1}.zip'.format(zip_file_url, name)

-    with closing(urlopen(zipFileURL + name + ".zip")) as url:
+    with urlopen(zip_file_path) as url:
         raw = url.read()

     with tempfile.TemporaryFile() as tmpf:
         tmpf.write(raw)

-        with closing(ZipFile(tmpf, 'r')) as zf:
+        with ZipFile(tmpf, 'r') as zf:
             data = zf.read(name + '.txt').splitlines()

-    file_edges = np.where(np.array([len(d) for d in data]) == 2)[0]
+    line_lengths = np.array(map(len, data))
+    file_edges = np.where(line_lengths)[0]

     datasets = {}
-    for i in xrange(len(file_edges) - 1):
-        dataset = [d.split() for d in data[(file_edges[i] + 1):
-                                           file_edges[i + 1]]]
+    edges = itertools.izip(file_edges[:-1], file_edges[1:])
+    for i, (left_edge, right_edge) in enumerate(edges):
+        dataset = [d.split() for d in data[left_edge:right_edge]]
         if len(dataset) > 10:
-            ncol = np.median(np.array([len(d) for d in dataset]))
-            header_index = np.where(
-                np.array([len(d) for d in dataset]) == (ncol - 1))[0][-1]
+            ncol_raw = np.array(map(len, dataset))
+            ncol = np.median(ncol_raw)
+            header_index = np.where(ncol_raw == ncol - 1)[0][-1]
             header = dataset[header_index]
+            ds_header = dataset[header_index + 1:]
             # to ensure the header is unique
-            header = ['{0} {1}'.format(j + 1, header_j) for j, header_j in
-                      enumerate(header)]
-            index = np.array(
-                [d[0] for d in dataset[(header_index + 1):]], dtype=int)
-            dataset = np.array(
-                [d[1:] for d in dataset[(header_index + 1):]], dtype=float)
+            header = ['{0} {1}'.format(*items) for items in enumerate(header,
+                                                                      start=1)]
+            index = np.fromiter((d[0] for d in ds_header), dtype=int)
+            dataset = np.fromiter((d[1:] for d in ds_header), dtype=float)
             datasets[i] = DataFrame(dataset, index, columns=header)

     return datasets
@@ -490,7 +487,7 @@ def get_data_famafrench(name, start=None, end=None):
 CUR_YEAR = dt.datetime.now().year


-def _unpack(row, kind='td'):
+def _unpack(row, kind):
     els = row.xpath('.//%s' % kind)
     return [val.text_content() for val in els]

@@ -498,7 +495,7 @@ def _unpack(row, kind='td'):
 def _parse_options_data(table):
     rows = table.xpath('.//tr')
     header = _unpack(rows[0], kind='th')
-    data = map(_unpack, rows[1:])
+    data = [_unpack(row, kind='td') for row in rows[1:]]
     # Use ',' as a thousands separator as we're pulling from the US site.
     return TextParser(data, names=header, na_values=['N/A'],
                       thousands=',').get_chunk()
@@ -615,13 +612,18 @@ def _get_option_data(self, month, year, expiry, table_loc, name):
             from lxml.html import parse
         except ImportError:
             raise ImportError("Please install lxml if you want to use the "
-                              "{0} class".format(self.__class__.__name__))
+                              "{0!r} class".format(self.__class__.__name__))
         try:
-            tables = parse(url).xpath('.//table')
-        except (AttributeError, IOError):
-            raise IndexError("Table location {0} invalid, unable to parse "
-                             "tables".format(table_loc))
+            doc = parse(url)
+        except _network_error_classes:
+            raise RemoteDataError("Unable to parse tables from URL "
+                                  "{0!r}".format(url))
         else:
+            root = doc.getroot()
+            if root is None:
+                raise RemoteDataError("Parsed URL {0!r} has no root"
+                                      "element".format(url))
+            tables = root.xpath('.//table')
             ntables = len(tables)
             if table_loc - 1 > ntables:
                 raise IndexError("Table location {0} invalid, {1} tables"
@@ -758,7 +760,7 @@ def get_near_stock_price(self, above_below=2, call=True, put=False,
             chop = df[get_range].dropna()
             chop.reset_index(inplace=True)
             data[nam] = chop
-        return [data[nam] for nam in sorted(to_ret)]
+        return [data[nam] for nam in to_ret]

     def _try_parse_dates(self, year, month, expiry):
         if year is not None or month is not None:
@@ -852,7 +854,7 @@ def get_forward_data(self, months, call=True, put=False, near=False,
                 else:
                     all_data = concat([all_data, frame])
                 data[name] = all_data
-        ret = [data[k] for k in sorted(data.keys())]
+        ret = [data[k] for k in to_ret]
         if len(ret) == 1:
             return ret.pop()
         if len(ret) != 2:
diff --git a/pandas/io/tests/test_cparser.py b/pandas/io/tests/test_cparser.py
index 71c0367cf5da3..7fa8d06f48ea3 100644
--- a/pandas/io/tests/test_cparser.py
+++ b/pandas/io/tests/test_cparser.py
@@ -34,7 +34,7 @@ class TestCParser(unittest.TestCase):

     def setUp(self):
-        self.dirpath = tm.get_data_path('/')
+        self.dirpath = tm.get_data_path()
         self.csv1 = os.path.join(self.dirpath, 'test1.csv')
         self.csv2 = os.path.join(self.dirpath, 'test2.csv')
         self.xls1 = os.path.join(self.dirpath, 'test.xls')
diff --git a/pandas/io/tests/test_yahoo.py b/pandas/io/tests/test_data.py
similarity index 61%
rename from pandas/io/tests/test_yahoo.py
rename to pandas/io/tests/test_data.py
index d6b65e7379d0a..2f4185154b8e6 100644
--- a/pandas/io/tests/test_yahoo.py
+++ b/pandas/io/tests/test_data.py
@@ -1,16 +1,88 @@
 import unittest
 import warnings
 import nose
+from nose.tools import assert_equal
 from datetime import datetime

-import pandas as pd
 import numpy as np
-import pandas.io.data as web
-from pandas.util.testing import (network, assert_series_equal,
-                                 assert_produces_warning, assert_frame_equal)
+
+import pandas as pd
+from pandas import DataFrame
+from pandas.io import data as web
+from pandas.io.data import DataReader, SymbolWarning
+from pandas.util.testing import (assert_series_equal, assert_produces_warning,
+                                 assert_frame_equal, network)
 from numpy.testing import assert_array_equal


+def assert_n_failed_equals_n_null_columns(wngs, obj, cls=SymbolWarning):
+    all_nan_cols = pd.Series(dict((k, pd.isnull(v).all()) for k, v in
in + obj.iteritems())) + n_all_nan_cols = all_nan_cols.sum() + valid_warnings = pd.Series([wng for wng in wngs if isinstance(wng, cls)]) + assert_equal(len(valid_warnings), n_all_nan_cols) + failed_symbols = all_nan_cols[all_nan_cols].index + msgs = valid_warnings.map(lambda x: x.message) + assert msgs.str.contains('|'.join(failed_symbols)).all() + + +class TestGoogle(unittest.TestCase): + + @network + def test_google(self): + # asserts that google is minimally working and that it throws + # an exception when DataReader can't get a 200 response from + # google + start = datetime(2010, 1, 1) + end = datetime(2013, 01, 27) + + self.assertEquals( + web.DataReader("F", 'google', start, end)['Close'][-1], + 13.68) + + self.assertRaises(Exception, web.DataReader, "NON EXISTENT TICKER", + 'google', start, end) + + @network + def test_get_quote_fails(self): + self.assertRaises(NotImplementedError, web.get_quote_google, + pd.Series(['GOOG', 'AAPL', 'GOOG'])) + + @network + def test_get_goog_volume(self): + df = web.get_data_google('GOOG') + self.assertEqual(df.Volume.ix['OCT-08-2010'], 2863473) + + @network + def test_get_multi1(self): + sl = ['AAPL', 'AMZN', 'GOOG'] + pan = web.get_data_google(sl, '2012') + + def testit(): + ts = pan.Close.GOOG.index[pan.Close.AAPL > pan.Close.GOOG] + self.assertEquals(ts[0].dayofyear, 96) + + if (hasattr(pan, 'Close') and hasattr(pan.Close, 'GOOG') and + hasattr(pan.Close, 'AAPL')): + testit() + else: + self.assertRaises(AttributeError, testit) + + @network + def test_get_multi2(self): + with warnings.catch_warnings(record=True) as w: + pan = web.get_data_google(['GE', 'MSFT', 'INTC'], 'JAN-01-12', + 'JAN-31-12') + result = pan.Close.ix['01-18-12'] + assert_n_failed_equals_n_null_columns(w, result) + + # sanity checking + + assert np.issubdtype(result.dtype, np.floating) + result = pan.Open.ix['Jan-15-12':'Jan-20-12'] + self.assertEqual((4, 3), result.shape) + assert_n_failed_equals_n_null_columns(w, result) + + class TestYahoo(unittest.TestCase): @classmethod def setUpClass(cls): @@ -111,7 +183,7 @@ def test_get_data_multiple_symbols_two_dates(self): [ 19.03, 28.16, 25.52], [ 18.81, 28.82, 25.87]]) result = pan.Open.ix['Jan-15-12':'Jan-20-12'] - assert_array_equal(np.array(expected).shape, result.shape) + self.assertEqual(expected.shape, result.shape) @network def test_get_date_ret_index(self): @@ -251,6 +323,82 @@ def test_get_put_data_warning(self): warnings.warn("IndexError thrown no tables found") +class TestDataReader(unittest.TestCase): + @network + def test_read_yahoo(self): + gs = DataReader("GS", "yahoo") + assert isinstance(gs, DataFrame) + + @network + def test_read_google(self): + gs = DataReader("GS", "google") + assert isinstance(gs, DataFrame) + + @network + def test_read_fred(self): + vix = DataReader("VIXCLS", "fred") + assert isinstance(vix, DataFrame) + + @network + def test_read_famafrench(self): + for name in ("F-F_Research_Data_Factors", + "F-F_Research_Data_Factors_weekly", "6_Portfolios_2x3", + "F-F_ST_Reversal_Factor"): + ff = DataReader(name, "famafrench") + assert isinstance(ff, dict) + + +class TestFred(unittest.TestCase): + @network + def test_fred(self): + """ + Throws an exception when DataReader can't get a 200 response from + FRED. 
+ """ + start = datetime(2010, 1, 1) + end = datetime(2013, 01, 27) + + self.assertEquals( + web.DataReader("GDP", "fred", start, end)['GDP'].tail(1), + 15984.1) + + self.assertRaises(Exception, web.DataReader, "NON EXISTENT SERIES", + 'fred', start, end) + + @network + def test_fred_nan(self): + start = datetime(2010, 1, 1) + end = datetime(2013, 01, 27) + df = web.DataReader("DFII5", "fred", start, end) + assert pd.isnull(df.ix['2010-01-01']) + + @network + def test_fred_parts(self): + start = datetime(2010, 1, 1) + end = datetime(2013, 01, 27) + df = web.get_data_fred("CPIAUCSL", start, end) + self.assertEqual(df.ix['2010-05-01'], 217.23) + + t = df.CPIAUCSL.values + assert np.issubdtype(t.dtype, np.floating) + self.assertEqual(t.shape, (37,)) + + @network + def test_fred_part2(self): + expected = [[576.7], + [962.9], + [684.7], + [848.3], + [933.3]] + result = web.get_data_fred("A09024USA144NNBR", start="1915").ix[:5] + assert_array_equal(result.values, np.array(expected)) + + @network + def test_invalid_series(self): + name = "NOT A REAL SERIES" + self.assertRaises(Exception, web.get_data_fred, name) + + if __name__ == '__main__': nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'], exit=False) diff --git a/pandas/io/tests/test_data_reader.py b/pandas/io/tests/test_data_reader.py deleted file mode 100644 index 129e35921335c..0000000000000 --- a/pandas/io/tests/test_data_reader.py +++ /dev/null @@ -1,30 +0,0 @@ -import unittest - -from pandas.core.generic import PandasObject -from pandas.io.data import DataReader -from pandas.util.testing import network - - -class TestDataReader(unittest.TestCase): - @network - def test_read_yahoo(self): - gs = DataReader("GS", "yahoo") - assert isinstance(gs, PandasObject) - - @network - def test_read_google(self): - gs = DataReader("GS", "google") - assert isinstance(gs, PandasObject) - - @network - def test_read_fred(self): - vix = DataReader("VIXCLS", "fred") - assert isinstance(vix, PandasObject) - - @network - def test_read_famafrench(self): - for name in ("F-F_Research_Data_Factors", - "F-F_Research_Data_Factors_weekly", "6_Portfolios_2x3", - "F-F_ST_Reversal_Factor"): - ff = DataReader(name, "famafrench") - assert isinstance(ff, dict) diff --git a/pandas/io/tests/test_fred.py b/pandas/io/tests/test_fred.py deleted file mode 100644 index e06f8f91e82a7..0000000000000 --- a/pandas/io/tests/test_fred.py +++ /dev/null @@ -1,65 +0,0 @@ -import unittest -import nose -from datetime import datetime - -import pandas as pd -import numpy as np -import pandas.io.data as web -from pandas.util.testing import network -from numpy.testing import assert_array_equal - - -class TestFred(unittest.TestCase): - @network - def test_fred(self): - """ - Throws an exception when DataReader can't get a 200 response from - FRED. 
- """ - start = datetime(2010, 1, 1) - end = datetime(2013, 01, 27) - - self.assertEquals( - web.DataReader("GDP", "fred", start, end)['GDP'].tail(1), - 15984.1) - - self.assertRaises(Exception, web.DataReader, "NON EXISTENT SERIES", - 'fred', start, end) - - @network - def test_fred_nan(self): - start = datetime(2010, 1, 1) - end = datetime(2013, 01, 27) - df = web.DataReader("DFII5", "fred", start, end) - assert pd.isnull(df.ix['2010-01-01']) - - @network - def test_fred_parts(self): - start = datetime(2010, 1, 1) - end = datetime(2013, 01, 27) - df = web.get_data_fred("CPIAUCSL", start, end) - self.assertEqual(df.ix['2010-05-01'], 217.23) - - t = df.CPIAUCSL.values - assert np.issubdtype(t.dtype, np.floating) - self.assertEqual(t.shape, (37,)) - - @network - def test_fred_part2(self): - expected = [[576.7], - [962.9], - [684.7], - [848.3], - [933.3]] - result = web.get_data_fred("A09024USA144NNBR", start="1915").ix[:5] - assert_array_equal(result.values, np.array(expected)) - - @network - def test_invalid_series(self): - name = "NOT A REAL SERIES" - self.assertRaises(Exception, web.get_data_fred, name) - - -if __name__ == '__main__': - nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'], - exit=False) diff --git a/pandas/io/tests/test_google.py b/pandas/io/tests/test_google.py deleted file mode 100644 index 65ae20fb5b505..0000000000000 --- a/pandas/io/tests/test_google.py +++ /dev/null @@ -1,73 +0,0 @@ -import unittest -import nose -from datetime import datetime - -import numpy as np -import pandas as pd -import pandas.io.data as web -from pandas.util.testing import network, with_connectivity_check - - -class TestGoogle(unittest.TestCase): - - @network - def test_google(self): - # asserts that google is minimally working and that it throws - # an exception when DataReader can't get a 200 response from - # google - start = datetime(2010, 1, 1) - end = datetime(2013, 01, 27) - - self.assertEquals( - web.DataReader("F", 'google', start, end)['Close'][-1], - 13.68) - - self.assertRaises(Exception, web.DataReader, "NON EXISTENT TICKER", - 'google', start, end) - - @network - def test_get_quote_fails(self): - self.assertRaises(NotImplementedError, web.get_quote_google, - pd.Series(['GOOG', 'AAPL', 'GOOG'])) - - @network - def test_get_goog_volume(self): - df = web.get_data_google('GOOG') - self.assertEqual(df.Volume.ix['OCT-08-2010'], 2863473) - - @network - def test_get_multi1(self): - sl = ['AAPL', 'AMZN', 'GOOG'] - pan = web.get_data_google(sl, '2012') - - def testit(): - ts = pan.Close.GOOG.index[pan.Close.AAPL > pan.Close.GOOG] - self.assertEquals(ts[0].dayofyear, 96) - - if (hasattr(pan, 'Close') and hasattr(pan.Close, 'GOOG') and - hasattr(pan.Close, 'AAPL')): - testit() - else: - self.assertRaises(AttributeError, testit) - - @network - def test_get_multi2(self): - pan = web.get_data_google(['GE', 'MSFT', 'INTC'], 'JAN-01-12', - 'JAN-31-12') - result = pan.Close.ix['01-18-12'] - self.assertEqual(len(result), 3) - - # sanity checking - assert np.issubdtype(result.dtype, np.floating) - - expected = np.array([[ 18.99, 28.4 , 25.18], - [ 18.58, 28.31, 25.13], - [ 19.03, 28.16, 25.52], - [ 18.81, 28.82, 25.87]]) - result = pan.Open.ix['Jan-15-12':'Jan-20-12'] - self.assertEqual(np.array(expected).shape, result.shape) - - -if __name__ == '__main__': - nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'], - exit=False) diff --git a/pandas/util/testing.py b/pandas/util/testing.py index c871e573719b9..47bde4ecb32a7 100644 --- a/pandas/util/testing.py +++ 
@@ -2,18 +2,20 @@
 # pylint: disable-msg=W0402

-from datetime import datetime
-from functools import wraps
 import random
 import string
 import sys
 import tempfile
 import warnings
+import inspect
+import os

-from contextlib import contextmanager  # contextlib is available since 2.5
-
+from datetime import datetime
+from functools import wraps
+from contextlib import contextmanager, closing
+from httplib import HTTPException
+from urllib2 import urlopen
 from distutils.version import LooseVersion
-import urllib2

 from numpy.random import randn
 import numpy as np
@@ -29,6 +31,8 @@
 from pandas.tseries.index import DatetimeIndex
 from pandas.tseries.period import PeriodIndex

+from pandas.io.common import urlopen
+
 Index = index.Index
 MultiIndex = index.MultiIndex
 Series = series.Series
@@ -81,7 +85,6 @@ def set_trace():

 #------------------------------------------------------------------------------
 # contextmanager to ensure the file cleanup
-from contextlib import contextmanager

 @contextmanager
 def ensure_clean(filename = None):
     # if we are not passed a filename, generate a temporary
@@ -91,30 +94,23 @@ def ensure_clean(filename = None):
     try:
         yield filename
     finally:
-        import os
         try:
             os.remove(filename)
         except:
             pass


-def get_data_path(f = None):
-    """ return the path of a data file, these are relative to the current test dir """
-
-    if f is None:
-        f = ''
-    import inspect, os
+def get_data_path(f=''):
+    """Return the path of a data file, these are relative to the current test
+    directory.
+    """
     # get our callers file
-    frame,filename,line_number,function_name,lines,index = \
-        inspect.getouterframes(inspect.currentframe())[1]
-
+    _, filename, _, _, _, _ = inspect.getouterframes(inspect.currentframe())[1]
     base_dir = os.path.abspath(os.path.dirname(filename))
-    return os.path.join(base_dir, 'data/%s' % f)
+    return os.path.join(base_dir, 'data', f)

 #------------------------------------------------------------------------------
 # Comparators
-
-
 def equalContents(arr1, arr2):
     """Checks if the set of unique elements of arr1 and arr2 are equivalent.
     """
@@ -692,9 +688,11 @@ def dec(f):
     return wrapper


+_network_error_classes = IOError, HTTPException
+
 @optional_args
 def network(t, raise_on_error=_RAISE_NETWORK_ERROR_DEFAULT,
-            error_classes=(IOError,)):
+            error_classes=_network_error_classes):
     """
     Label a test as requiring network connection and skip test if it
     encounters a ``URLError``.
@@ -727,6 +725,7 @@ def network(t, raise_on_error=_RAISE_NETWORK_ERROR_DEFAULT,

     >>> from pandas.util.testing import network
     >>> import urllib2
+    >>> import nose
     >>> @network
     ... def test_network():
     ...     urllib2.urlopen("rabbit://bonanza.com")
@@ -770,12 +769,25 @@ def network_wrapper(*args, **kwargs):
     return network_wrapper


-def can_connect(url):
-    """tries to connect to the given url. True if succeeds, False if IOError raised"""
+def can_connect(url, error_classes=_network_error_classes):
+    """Try to connect to the given url. True if succeeds, False if IOError
+    raised
+
+    Parameters
+    ----------
+    url : basestring
+        The URL to try to connect to
+
+    Returns
+    -------
+    connectable : bool
+        Return True if no IOError (unable to connect) or URLError (bad url) was
+        raised
+    """
     try:
-        with closing(urllib2.urlopen(url)) as resp:
+        with urlopen(url):
             pass
-    except IOError:
+    except error_classes:
         return False
     else:
         return True
@@ -783,8 +795,9 @@ def can_connect(url):

 @optional_args
 def with_connectivity_check(t, url="http://www.google.com",
-                            raise_on_error=_RAISE_NETWORK_ERROR_DEFAULT, check_before_test=False,
-                            error_classes=IOError):
+                            raise_on_error=_RAISE_NETWORK_ERROR_DEFAULT,
+                            check_before_test=False,
+                            error_classes=_network_error_classes):
     """
     Label a test as requiring network connection and, if an error is
     encountered, only raise if it does not find a network connection.
@@ -811,7 +824,10 @@ def with_connectivity_check(t, url="http://www.google.com",
         error classes to ignore. If not in ``error_classes``, raises the error.
         defaults to IOError. Be careful about changing the error classes here.

-    NOTE: ``raise_on_error`` supercedes ``check_before_test``
+    Notes
+    -----
+    * ``raise_on_error`` supercedes ``check_before_test``
+
     Returns
     -------
     t : callable
@@ -846,12 +862,12 @@ def with_connectivity_check(t, url="http://www.google.com",
     @wraps(t)
     def wrapper(*args, **kwargs):
         if check_before_test and not raise_on_error:
-            if not can_connect(url):
+            if not can_connect(url, error_classes):
                 raise SkipTest
         try:
             return t(*args, **kwargs)
         except error_classes as e:
-            if raise_on_error or can_connect(url):
+            if raise_on_error or can_connect(url, error_classes):
                 raise
             else:
                 raise SkipTest("Skipping test due to lack of connectivity"
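
Reviewer note (not part of the patch): the sketch below illustrates how the shared `urlopen`/`ZipFile` context managers that this diff adds to `pandas.io.common` are meant to be consumed, mirroring the `get_data_famafrench` hunk above. It assumes the Python 2 environment this diff targets; the URL is a placeholder, not anything referenced by the patch.

    # Hypothetical usage of the helpers introduced in pandas/io/common.py.
    # Both names behave as context managers on every supported interpreter:
    # where the stdlib objects already support `with` they are used directly,
    # otherwise the shim wraps them in contextlib.closing().
    import tempfile

    from pandas.io.common import urlopen, ZipFile

    with urlopen('http://example.com/data.zip') as resp:  # placeholder URL
        raw = resp.read()  # the socket is closed exactly once on exit

    with tempfile.TemporaryFile() as tmpf:
        tmpf.write(raw)
        with ZipFile(tmpf, 'r') as zf:  # also works on Python <= 2.6
            print zf.namelist()

Centralizing the `closing()` logic this way is what lets the rest of the diff replace every `with closing(urlopen(...))` call site with a plain `with urlopen(...)`.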