diff --git a/pandas/io/data.py b/pandas/io/data.py index b0ee77f11a0a7..278fc2fc6dd4d 100644 --- a/pandas/io/data.py +++ b/pandas/io/data.py @@ -10,12 +10,12 @@ import datetime as dt import urllib import time +from collections import defaultdict from contextlib import closing from urllib2 import urlopen from zipfile import ZipFile from pandas.util.py3compat import StringIO, bytes_to_str - from pandas import Panel, DataFrame, Series, read_csv, concat from pandas.io.parsers import TextParser @@ -56,17 +56,17 @@ def DataReader(name, data_source=None, start=None, end=None, """ start, end = _sanitize_dates(start, end) - if(data_source == "yahoo"): + if data_source == "yahoo": return get_data_yahoo(symbols=name, start=start, end=end, adjust_price=False, chunk=25, retry_count=retry_count, pause=pause) - elif(data_source == "google"): + elif data_source == "google": return get_data_google(symbols=name, start=start, end=end, - adjust_price=False, chunk=25, - retry_count=retry_count, pause=pause) - elif(data_source == "fred"): + adjust_price=False, chunk=25, + retry_count=retry_count, pause=pause) + elif data_source == "fred": return get_data_fred(name=name, start=start, end=end) - elif(data_source == "famafrench"): + elif data_source == "famafrench": return get_data_famafrench(name=name) @@ -94,21 +94,21 @@ def get_quote_yahoo(symbols): Returns a DataFrame """ - if isinstance(symbols, str): + if isinstance(symbols, basestring): sym_list = symbols elif not isinstance(symbols, Series): - symbols = Series(symbols) - sym_list = str.join('+', symbols) + symbols = Series(symbols) + sym_list = '+'.join(symbols) else: - sym_list = str.join('+', symbols) + sym_list = '+'.join(symbols) # for codes see: http://www.gummy-stuff.org/Yahoo-data.htm codes = {'symbol': 's', 'last': 'l1', 'change_pct': 'p2', 'PE': 'r', 'time': 't1', 'short_ratio': 's7'} - request = str.join('', codes.values()) # code request string + request = ''.join(codes.itervalues()) # code request string header = codes.keys() - data = dict(zip(codes.keys(), [[] for i in range(len(codes))])) + data = defaultdict(list) url_str = 'http://finance.yahoo.com/d/quotes.csv?s=%s&f=%s' % (sym_list, request) @@ -120,14 +120,15 @@ def get_quote_yahoo(symbols): fields = line.decode('utf-8').strip().split(',') for i, field in enumerate(fields): if field[-2:] == '%"': - data[header[i]].append(float(field.strip('"%'))) + v = float(field.strip('"%')) elif field[0] == '"': - data[header[i]].append(field.strip('"')) + v = field.strip('"') else: try: - data[header[i]].append(float(field)) + v = float(field) except ValueError: - data[header[i]].append(np.nan) + v = np.nan + data[header[i]].append(v) idx = data.pop('symbol') @@ -137,18 +138,15 @@ def get_quote_yahoo(symbols): def get_quote_google(symbols): raise NotImplementedError("Google Finance doesn't have this functionality") -def _get_hist_yahoo(sym=None, start=None, end=None, retry_count=3, - pause=0.001, **kwargs): + +def _get_hist_yahoo(sym, start=None, end=None, retry_count=3, pause=0.001, + **kwargs): """ Get historical data for the given name from yahoo. Date format is datetime Returns a DataFrame. """ - if(sym is None): - warnings.warn("Need to provide a name.") - return None - start, end = _sanitize_dates(start, end) yahoo_URL = 'http://ichart.yahoo.com/table.csv?' @@ -179,22 +177,18 @@ def _get_hist_yahoo(sym=None, start=None, end=None, retry_count=3, time.sleep(pause) - raise Exception("after %d tries, Yahoo did not " - "return a 200 for url %s" % (pause, url)) + raise IOError("after %d tries, Yahoo did not " + "return a 200 for url %r" % (retry_count, url)) -def _get_hist_google(sym=None, start=None, end=None, retry_count=3, - pause=0.001, **kwargs): +def _get_hist_google(sym, start=None, end=None, retry_count=3, pause=0.001, + **kwargs): """ Get historical data for the given name from google. Date format is datetime Returns a DataFrame. """ - if(sym is None): - warnings.warn("Need to provide a name.") - return None - start, end = _sanitize_dates(start, end) google_URL = 'http://www.google.com/finance/historical?' @@ -208,16 +202,15 @@ def _get_hist_google(sym=None, start=None, end=None, retry_count=3, for _ in xrange(retry_count): with closing(urlopen(url)) as resp: if resp.code == 200: - lines = resp.read() - rs = read_csv(StringIO(bytes_to_str(lines)), index_col=0, + rs = read_csv(StringIO(bytes_to_str(resp.read())), index_col=0, parse_dates=True)[::-1] return rs time.sleep(pause) - raise Exception("after %d tries, Google did not " - "return a 200 for url %s" % (pause, url)) + raise IOError("after %d tries, Google did not " + "return a 200 for url %s" % (retry_count, url)) def _adjust_prices(hist_data, price_list=['Open', 'High', 'Low', 'Close']): @@ -244,9 +237,9 @@ def _calc_return_index(price_df): mask = ~df.ix[1].isnull() & df.ix[0].isnull() df.ix[0][mask] = 1 - #Check for first stock listings after starting date of index in ret_index - #If True, find first_valid_index and set previous entry to 1. - if(~mask).any(): + # Check for first stock listings after starting date of index in ret_index + # If True, find first_valid_index and set previous entry to 1. + if (~mask).any(): for sym in mask.index[~mask]: tstamp = df[sym].first_valid_index() t_idx = df.index.get_loc(tstamp) - 1 @@ -278,10 +271,10 @@ def get_components_yahoo(idx_sym): idx_df : DataFrame """ stats = 'snx' - #URL of form: - #http://download.finance.yahoo.com/d/quotes.csv?s=@%5EIXIC&f=snxl1d1t1c1ohgv - url = 'http://download.finance.yahoo.com/d/quotes.csv?s={0}&f={1}' \ - '&e=.csv&h={2}' + # URL of form: + # http://download.finance.yahoo.com/d/quotes.csv?s=@%5EIXIC&f=snxl1d1t1c1ohgv + url = ('http://download.finance.yahoo.com/d/quotes.csv?s={0}&f={1}' + '&e=.csv&h={2}') idx_mod = idx_sym.replace('^', '@%5E') url_str = url.format(idx_mod, stats, 1) @@ -310,9 +303,22 @@ def get_components_yahoo(idx_sym): return idx_df -def get_data_yahoo(symbols=None, start=None, end=None, retry_count=3, pause=0.001, - adjust_price=False, ret_index=False, chunksize=25, - **kwargs): +def _dl_mult_symbols(symbols, start, end, chunksize, pause, method, **kwargs): + stocks = {} + for sym_group in _in_chunks(symbols, chunksize): + for sym in sym_group: + try: + stocks[sym] = method(sym, start=start, end=end, pause=pause, + **kwargs) + except IOError: + warnings.warn('ERROR with symbol: {0}, skipping.'.format(sym)) + + return Panel(stocks).swapaxes('items', 'minor') + + +def get_data_yahoo(symbols=None, start=None, end=None, retry_count=3, + pause=0.001, adjust_price=False, ret_index=False, + chunksize=25, **kwargs): """ Returns DataFrame/Panel of historical stock prices from symbols, over date range, start to end. To avoid being penalized by Yahoo! Finance servers, @@ -334,8 +340,8 @@ def get_data_yahoo(symbols=None, start=None, end=None, retry_count=3, pause=0.00 Time, in seconds, to pause between consecutive queries of chunks. If single value given for symbol, represents the pause between retries. adjust_price : bool, default False - If True, adjusts all prices in hist_data ('Open', 'High', 'Low', 'Close') - based on 'Adj Close' price. Adds 'Adj_Ratio' column and drops + If True, adjusts all prices in hist_data ('Open', 'High', 'Low', + 'Close') based on 'Adj Close' price. Adds 'Adj_Ratio' column and drops 'Adj Close'. ret_index : bool, default False If True, includes a simple return index 'Ret_Index' in hist_data. @@ -346,49 +352,30 @@ def get_data_yahoo(symbols=None, start=None, end=None, retry_count=3, pause=0.00 ------- hist_data : DataFrame (str) or Panel (array-like object, DataFrame) """ - - def dl_mult_symbols(symbols): - stocks = {} - for sym_group in _in_chunks(symbols, chunksize): - for sym in sym_group: - try: - stocks[sym] = _get_hist_yahoo(sym, start=start, - end=end, **kwargs) - except: - warnings.warn('Error with sym: ' + sym + '... skipping.') - - time.sleep(pause) - - return Panel(stocks).swapaxes('items', 'minor') - if 'name' in kwargs: - warnings.warn("Arg 'name' is deprecated, please use 'symbols' instead.", - FutureWarning) + warnings.warn("Arg 'name' is deprecated, please use 'symbols' " + "instead.", FutureWarning) symbols = kwargs['name'] - #If a single symbol, (e.g., 'GOOG') - if isinstance(symbols, (str, int)): - sym = symbols - hist_data = _get_hist_yahoo(sym, start=start, end=end) - #Or multiple symbols, (e.g., ['GOOG', 'AAPL', 'MSFT']) + # If a single symbol, (e.g., 'GOOG') + if isinstance(symbols, (basestring, int)): + hist_data = _get_hist_yahoo(symbols, start=start, end=end) + # Or multiple symbols, (e.g., ['GOOG', 'AAPL', 'MSFT']) elif isinstance(symbols, DataFrame): - try: - hist_data = dl_mult_symbols(Series(symbols.index)) - except ValueError: - raise - else: #Guess a Series - try: - hist_data = dl_mult_symbols(symbols) - except TypeError: - hist_data = dl_mult_symbols(Series(symbols)) + hist_data = _dl_mult_symbols(symbols.index, start, end, chunksize, + pause, _get_hist_yahoo, **kwargs) + else: + hist_data = _dl_mult_symbols(symbols, start, end, chunksize, pause, + _get_hist_yahoo, **kwargs) - if(ret_index): + if ret_index: hist_data['Ret_Index'] = _calc_return_index(hist_data['Adj Close']) - if(adjust_price): + if adjust_price: hist_data = _adjust_prices(hist_data) return hist_data + def get_data_google(symbols=None, start=None, end=None, retry_count=3, pause=0.001, chunksize=25, **kwargs): """ @@ -418,45 +405,24 @@ def get_data_google(symbols=None, start=None, end=None, retry_count=3, ------- hist_data : DataFrame (str) or Panel (array-like object, DataFrame) """ - - def dl_mult_symbols(symbols): - stocks = {} - for sym_group in _in_chunks(symbols, chunksize): - for sym in sym_group: - try: - stocks[sym] = _get_hist_google(sym, start=start, - end=end, **kwargs) - except: - warnings.warn('Error with sym: ' + sym + '... skipping.') - - time.sleep(pause) - - return Panel(stocks).swapaxes('items', 'minor') - if 'name' in kwargs: - warnings.warn("Arg 'name' is deprecated, please use 'symbols' instead.", - FutureWarning) + warnings.warn("Arg 'name' is deprecated, please use 'symbols' " + "instead.", FutureWarning) symbols = kwargs['name'] - #If a single symbol, (e.g., 'GOOG') - if isinstance(symbols, (str, int)): - sym = symbols - hist_data = _get_hist_google(sym, start=start, end=end) - #Or multiple symbols, (e.g., ['GOOG', 'AAPL', 'MSFT']) + # If a single symbol, (e.g., 'GOOG') + if isinstance(symbols, (basestring, int)): + return _get_hist_google(symbols, start=start, end=end) + # Or multiple symbols, (e.g., ['GOOG', 'AAPL', 'MSFT']) elif isinstance(symbols, DataFrame): - try: - hist_data = dl_mult_symbols(Series(symbols.index)) - except ValueError: - raise - else: #Guess a Series - try: - hist_data = dl_mult_symbols(symbols) - except TypeError: - hist_data = dl_mult_symbols(Series(symbols)) + symbs = symbols.index + else: # Guess a Series + symbs = symbols + return _dl_mult_symbols(symbs, start, end, chunksize, pause, + _get_hist_google, **kwargs) - return hist_data -def get_data_fred(name=None, start=dt.datetime(2010, 1, 1), +def get_data_fred(name, start=dt.datetime(2010, 1, 1), end=dt.datetime.today()): """ Get data for the given name from the St. Louis FED (FRED). @@ -466,10 +432,6 @@ def get_data_fred(name=None, start=dt.datetime(2010, 1, 1), """ start, end = _sanitize_dates(start, end) - if(name is None): - print ("Need to provide a name") - return None - fred_URL = "http://research.stlouisfed.org/fred2/series/" url = fred_URL + '%s' % name + '/downloaddata/%s' % name + '.csv' @@ -481,11 +443,10 @@ def get_data_fred(name=None, start=dt.datetime(2010, 1, 1), return data.truncate(start, end) except KeyError: if data.ix[3].name[7:12] == 'Error': - raise Exception("Failed to get the data. " - "Check that {0!r} is valid FRED " - "series.".format(name)) - else: - raise + raise IOError("Failed to get the data. Check that {0!r} is " + "a valid FRED series.".format(name)) + raise + def get_data_famafrench(name, start=None, end=None): start, end = _sanitize_dates(start, end) @@ -505,16 +466,17 @@ def get_data_famafrench(name, start=None, end=None): file_edges = np.where(np.array([len(d) for d in data]) == 2)[0] datasets = {} - for i in range(len(file_edges) - 1): + for i in xrange(len(file_edges) - 1): dataset = [d.split() for d in data[(file_edges[i] + 1): file_edges[i + 1]]] - if(len(dataset) > 10): + if len(dataset) > 10: ncol = np.median(np.array([len(d) for d in dataset])) header_index = np.where( np.array([len(d) for d in dataset]) == (ncol - 1))[0][-1] header = dataset[header_index] # to ensure the header is unique - header = [str(j + 1) + " " + header[j] for j in range(len(header))] + header = ['{0} {1}'.format(j + 1, header_j) for j, header_j in + enumerate(header)] index = np.array( [d[0] for d in dataset[(header_index + 1):]], dtype=int) dataset = np.array( @@ -524,24 +486,28 @@ def get_data_famafrench(name, start=None, end=None): return datasets # Items needed for options class -cur_month = dt.datetime.now().month -cur_year = dt.datetime.now().year +CUR_MONTH = dt.datetime.now().month +CUR_YEAR = dt.datetime.now().year def _unpack(row, kind='td'): - els = row.findall('.//%s' % kind) - return[val.text_content() for val in els] + els = row.xpath('.//%s' % kind) + return [val.text_content() for val in els] def _parse_options_data(table): - rows = table.findall('.//tr') + rows = table.xpath('.//tr') header = _unpack(rows[0], kind='th') - data = [_unpack(r) for r in rows[1:]] + data = map(_unpack, rows[1:]) # Use ',' as a thousands separator as we're pulling from the US site. return TextParser(data, names=header, na_values=['N/A'], thousands=',').get_chunk() +def _two_char_month(s): + return '{0:0>2}'.format(s) + + class Options(object): """ This class fetches call/put data for a given stock/expiry month. @@ -549,11 +515,11 @@ class Options(object): It is instantiated with a string representing the ticker symbol. The class has the following methods: - get_options_data:(month, year) - get_call_data:(month, year) - get_put_data: (month, year) + get_options:(month, year) + get_calls:(month, year) + get_puts: (month, year) get_near_stock_price(opt_frame, above_below) - get_forward_data(months, call, put) + get_forward(months, call, put) Examples -------- @@ -561,13 +527,13 @@ class Options(object): >>> aapl = Options('aapl', 'yahoo') # Fetch September 2012 call data - >>> calls = aapl.get_call_data(9, 2012) + >>> calls = aapl.get_calls(9, 2012) # Can now access aapl.calls instance variable >>> aapl.calls # Fetch September 2012 put data - >>> puts = aapl.get_put_data(9, 2012) + >>> puts = aapl.get_puts(9, 2012) # Can now access aapl.puts instance variable >>> aapl.puts @@ -580,15 +546,14 @@ class Options(object): ... call=True, put=True) """ - def __init__(self, symbol, data_source=None): """ Instantiates options_data with a ticker saved as symbol """ - self.symbol = str(symbol).upper() - if (data_source is None): - warnings.warn("Options(symbol) is deprecated, use Options(symbol, data_source) instead", - FutureWarning) + self.symbol = symbol.upper() + if data_source is None: + warnings.warn("Options(symbol) is deprecated, use Options(symbol," + " data_source) instead", FutureWarning) data_source = "yahoo" - if (data_source != "yahoo"): + if data_source != "yahoo": raise NotImplementedError("currently only yahoo supported") def get_options_data(self, month=None, year=None, expiry=None): @@ -617,7 +582,7 @@ def get_options_data(self, month=None, year=None, expiry=None): >>> aapl = Options('aapl', 'yahoo') # Create object >>> aapl.calls # will give an AttributeError - >>> aapl.get_options_data() # Get data and set ivars + >>> aapl.get_options() # Get data and set ivars >>> aapl.calls # Doesn't throw AttributeError Also note that aapl.calls and appl.puts will always be the calls @@ -627,45 +592,47 @@ def get_options_data(self, month=None, year=None, expiry=None): representations of the month and year for the expiry of the options. """ - year, month = self._try_parse_dates(year,month,expiry) + return [f(month, year, expiry) for f in (self.get_put_data, + self.get_call_data)] - from lxml.html import parse + def _get_option_data(self, month, year, expiry, table_loc, name): + year, month = self._try_parse_dates(year, month, expiry) - if month and year: # try to get specified month from yahoo finance - m1 = month if len(str(month)) == 2 else '0' + str(month) - m2 = month + url = 'http://finance.yahoo.com/q/op?s={sym}'.format(sym=self.symbol) - if m1 != cur_month and m2 != cur_month: # if this month use other url - url = str('http://finance.yahoo.com/q/op?s=' + self.symbol + - '&m=' + str(year) + '-' + str(m1)) + if month and year: # try to get specified month from yahoo finance + m1, m2 = _two_char_month(month), month + # if this month use other url + if m1 != CUR_MONTH and m2 != CUR_MONTH: + url += '&m={year}-{m1}'.format(year=year, m1=m1) else: - url = str('http://finance.yahoo.com/q/op?s=' + self.symbol + - '+Options') - + url += '+Options' else: # Default to current month - url = str('http://finance.yahoo.com/q/op?s=' + self.symbol + - '+Options') + url += '+Options' - parsed = parse(url) - doc = parsed.getroot() - tables = doc.findall('.//table') - calls = tables[9] - puts = tables[13] + try: + from lxml.html import parse + except ImportError: + raise ImportError("Please install lxml if you want to use the " + "{0} class".format(self.__class__.__name__)) + try: + tables = parse(url).xpath('.//table') + except (AttributeError, IOError): + raise IndexError("Table location {0} invalid, unable to parse " + "tables".format(table_loc)) + else: + ntables = len(tables) + if table_loc - 1 > ntables: + raise IndexError("Table location {0} invalid, {1} tables" + " found".format(table_loc, ntables)) - call_data = _parse_options_data(calls) - put_data = _parse_options_data(puts) + option_data = _parse_options_data(tables[table_loc]) if month: - c_name = 'calls' + str(m1) + str(year)[2:] - p_name = 'puts' + str(m1) + str(year)[2:] - self.__setattr__(c_name, call_data) - self.__setattr__(p_name, put_data) - else: - self.calls = call_data - self.calls = put_data - - return [call_data, put_data] + name += m1 + str(year)[-2:] + setattr(self, name, option_data) + return option_data def get_call_data(self, month=None, year=None, expiry=None): """ @@ -698,40 +665,7 @@ def get_call_data(self, month=None, year=None, expiry=None): repsectively, two digit representations of the month and year for the expiry of the options. """ - year, month = self._try_parse_dates(year,month,expiry) - - from lxml.html import parse - - if month and year: # try to get specified month from yahoo finance - m1 = month if len(str(month)) == 2 else '0' + str(month) - m2 = month - - if m1 != cur_month and m2 != cur_month: # if this month use other url - url = str('http://finance.yahoo.com/q/op?s=' + self.symbol + - '&m=' + str(year) + '-' + str(m1)) - - else: - url = str('http://finance.yahoo.com/q/op?s=' + self.symbol + - '+Options') - - else: # Default to current month - url = str('http://finance.yahoo.com/q/op?s=' + self.symbol + - '+Options') - - parsed = parse(url) - doc = parsed.getroot() - tables = doc.findall('.//table') - calls = tables[9] - - call_data = _parse_options_data(calls) - - if month: - name = 'calls' + str(m1) + str(year)[2:] - self.__setattr__(name, call_data) - else: - self.calls = call_data - - return call_data + return self._get_option_data(month, year, expiry, 9, 'calls') def get_put_data(self, month=None, year=None, expiry=None): """ @@ -766,40 +700,7 @@ def get_put_data(self, month=None, year=None, expiry=None): repsectively, two digit representations of the month and year for the expiry of the options. """ - year, month = self._try_parse_dates(year,month,expiry) - - from lxml.html import parse - - if month and year: # try to get specified month from yahoo finance - m1 = month if len(str(month)) == 2 else '0' + str(month) - m2 = month - - if m1 != cur_month and m2 != cur_month: # if this month use other url - url = str('http://finance.yahoo.com/q/op?s=' + self.symbol + - '&m=' + str(year) + '-' + str(m1)) - - else: - url = str('http://finance.yahoo.com/q/op?s=' + self.symbol + - '+Options') - - else: # Default to current month - url = str('http://finance.yahoo.com/q/op?s=' + self.symbol + - '+Options') - - parsed = parse(url) - doc = parsed.getroot() - tables = doc.findall('.//table') - puts = tables[13] - - put_data = _parse_options_data(puts) - - if month: - name = 'puts' + str(m1) + str(year)[2:] - self.__setattr__(name, put_data) - else: - self.puts = put_data - - return put_data + return self._get_option_data(month, year, expiry, 13, 'puts') def get_near_stock_price(self, above_below=2, call=True, put=False, month=None, year=None, expiry=None): @@ -831,68 +732,42 @@ def get_near_stock_price(self, above_below=2, call=True, put=False, desired. If there isn't data as far out as the user has asked for then """ - year, month = self._try_parse_dates(year,month,expiry) - + year, month = self._try_parse_dates(year, month, expiry) price = float(get_quote_yahoo([self.symbol])['last']) - if call: - try: - if month: - m1 = month if len(str(month)) == 2 else '0' + str(month) - name = 'calls' + str(m1) + str(year)[2:] - df_c = self.__getattribute__(name) - else: - df_c = self.calls - except AttributeError: - df_c = self.get_call_data(month, year) - - start_index = np.where(df_c['Strike'] > price)[0][0] - - get_range = range(start_index - above_below, - start_index + above_below + 1) + to_ret = Series({'calls': call, 'puts': put}) + to_ret = to_ret[to_ret].index - chop_call = df_c.ix[get_range, :] + data = {} - chop_call = chop_call.dropna(how='all') - chop_call = chop_call.reset_index() + for nam in to_ret: + if month: + m1 = _two_char_month(month) + name = nam + m1 + str(year)[2:] - if put: try: - if month: - m1 = month if len(str(month)) == 2 else '0' + str(month) - name = 'puts' + str(m1) + str(year)[2:] - df_p = self.__getattribute__(name) - else: - df_p = self.puts + df = getattr(self, name) except AttributeError: - df_p = self.get_put_data(month, year) + meth_name = 'get_{0}_data'.format(nam[:-1]) + df = getattr(self, meth_name)(month, year) - start_index = np.where(df_p.Strike > price)[0][0] + start_index = np.where(df['Strike'] > price)[0][0] - get_range = range(start_index - above_below, + get_range = slice(start_index - above_below, start_index + above_below + 1) - - chop_put = df_p.ix[get_range, :] - - chop_put = chop_put.dropna(how='all') - chop_put = chop_put.reset_index() - - if call and put: - return [chop_call, chop_put] - else: - if call: - return chop_call - else: - return chop_put + chop = df[get_range].dropna() + chop.reset_index(inplace=True) + data[nam] = chop + return [data[nam] for nam in sorted(to_ret)] def _try_parse_dates(self, year, month, expiry): if year is not None or month is not None: - warnings.warn("month, year arguments are deprecated, use expiry instead", - FutureWarning) + warnings.warn("month, year arguments are deprecated, use expiry" + " instead", FutureWarning) if expiry is not None: - year=expiry.year - month=expiry.month + year = expiry.year + month = expiry.month return year, month def get_forward_data(self, months, call=True, put=False, near=False, @@ -923,106 +798,63 @@ def get_forward_data(self, months, call=True, put=False, near=False, Returns ------- - all_calls: DataFrame - If asked for, a DataFrame containing call data from the current - month to the current month plus months. - - all_puts: DataFrame - If asked for, a DataFrame containing put data from the current - month to the current month plus months. + data : dict of str, DataFrame """ warnings.warn("get_forward_data() is deprecated", FutureWarning) - in_months = range(cur_month, cur_month + months + 1) - in_years = [cur_year] * (months + 1) + in_months = xrange(CUR_MONTH, CUR_MONTH + months + 1) + in_years = [CUR_YEAR] * (months + 1) # Figure out how many items in in_months go past 12 to_change = 0 - for i in range(months): + for i in xrange(months): if in_months[i] > 12: in_months[i] -= 12 to_change += 1 # Change the corresponding items in the in_years list. - for i in range(1, to_change + 1): + for i in xrange(1, to_change + 1): in_years[-i] += 1 - if call: - all_calls = DataFrame() - for mon in range(months): - m2 = in_months[mon] - y2 = in_years[mon] - try: # This catches cases when there isn't data for a month - if not near: - try: # Try to access the ivar if already instantiated - - m1 = m2 if len(str(m2)) == 2 else '0' + str(m2) - name = 'calls' + str(m1) + str(y2)[2:] - call_frame = self.__getattribute__(name) - except: - call_frame = self.get_call_data(in_months[mon], - in_years[mon]) - - else: - call_frame = self.get_near_stock_price(call=True, - put=False, - above_below=above_below, - month=m2, year=y2) - - tick = str(call_frame.Symbol[0]) - start = len(self.symbol) - year = tick[start: start + 2] - month = tick[start + 2: start + 4] - day = tick[start + 4: start + 6] - expiry = str(month + '-' + day + '-' + year) - call_frame['Expiry'] = expiry - if mon == 0: - all_calls = all_calls.join(call_frame, how='right') - else: - all_calls = concat([all_calls, call_frame]) - except: - pass - - if put: - all_puts = DataFrame() - for mon in range(months): + to_ret = Series({'calls': call, 'puts': put}) + to_ret = to_ret[to_ret].index + data = {} + + for name in to_ret: + all_data = DataFrame() + + for mon in xrange(months): m2 = in_months[mon] y2 = in_years[mon] - try: # This catches cases when there isn't data for a month - if not near: - try: # Try to access the ivar if already instantiated - m1 = m2 if len(str(m2)) == 2 else '0' + str(m2) - name = 'puts' + str(m1) + str(y2)[2:] - put_frame = self.__getattribute__(name) - except: - put_frame = self.get_call_data(in_months[mon], - in_years[mon]) - - else: - put_frame = self.get_near_stock_price(call=False, - put=True, - above_below=above_below, - month=m2, year=y2) - - # Add column with expiry data to this frame. - tick = str(put_frame.Symbol[0]) - start = len(self.symbol) - year = tick[start: start + 2] - month = tick[start + 2: start + 4] - day = tick[start + 4: start + 6] - expiry = str(month + '-' + day + '-' + year) - put_frame['Expiry'] = expiry - - if mon == 0: - all_puts = all_puts.join(put_frame, how='right') - else: - all_puts = concat([all_puts, put_frame]) - except: - pass - - if call and put: - return [all_calls, all_puts] - else: - if call: - return all_calls - else: - return all_puts + + if not near: + m1 = _two_char_month(m2) + nam = name + str(m1) + str(y2)[2:] + + try: # Try to access on the instance + frame = getattr(self, nam) + except AttributeError: + meth_name = 'get_{0}_data'.format(name[:-1]) + frame = getattr(self, meth_name)(m2, y2) + else: + frame = self.get_near_stock_price(call=call, put=put, + above_below=above_below, + month=m2, year=y2) + tick = str(frame.Symbol[0]) + start = len(self.symbol) + year = tick[start:start + 2] + month = tick[start + 2:start + 4] + day = tick[start + 4:start + 6] + expiry = month + '-' + day + '-' + year + frame['Expiry'] = expiry + + if not mon: + all_data = all_data.join(frame, how='right') + else: + all_data = concat([all_data, frame]) + data[name] = all_data + ret = [data[k] for k in sorted(data.keys())] + if len(ret) == 1: + return ret.pop() + if len(ret) != 2: + raise AssertionError("should be len 2") + return ret diff --git a/pandas/io/tests/test_yahoo.py b/pandas/io/tests/test_yahoo.py index 3d4252f99cbd5..d6b65e7379d0a 100644 --- a/pandas/io/tests/test_yahoo.py +++ b/pandas/io/tests/test_yahoo.py @@ -1,4 +1,5 @@ import unittest +import warnings import nose from datetime import datetime @@ -6,7 +7,7 @@ import numpy as np import pandas.io.data as web from pandas.util.testing import (network, assert_series_equal, - assert_produces_warning) + assert_produces_warning, assert_frame_equal) from numpy.testing import assert_array_equal @@ -37,10 +38,21 @@ def test_yahoo_fails(self): 'yahoo', start, end) @network - def test_get_quote(self): + def test_get_quote_series(self): df = web.get_quote_yahoo(pd.Series(['GOOG', 'AAPL', 'GOOG'])) assert_series_equal(df.ix[0], df.ix[2]) + @network + def test_get_quote_string(self): + df = web.get_quote_yahoo('GOOG') + df2 = web.get_quote_yahoo('GOOG') + assert_frame_equal(df, df2) + + @network + def test_get_quote_stringlist(self): + df = web.get_quote_yahoo(['GOOG', 'AAPL', 'GOOG']) + assert_series_equal(df.ix[0], df.ix[2]) + @network def test_get_components_dow_jones(self): df = web.get_components_yahoo('^DJI') #Dow Jones @@ -139,26 +151,42 @@ def tearDownClass(cls): @network def test_get_options_data(self): - calls, puts = self.aapl.get_options_data(expiry=self.expiry) - assert len(calls)>1 - assert len(puts)>1 + try: + calls, puts = self.aapl.get_options_data(expiry=self.expiry) + except IndexError: + warnings.warn("IndexError thrown no tables found") + else: + assert len(calls)>1 + assert len(puts)>1 @network def test_get_near_stock_price(self): - calls, puts = self.aapl.get_near_stock_price(call=True, put=True, - expiry=self.expiry) - self.assertEqual(len(calls), 5) - self.assertEqual(len(puts), 5) + try: + calls, puts = self.aapl.get_near_stock_price(call=True, put=True, + expiry=self.expiry) + except IndexError: + warnings.warn("IndexError thrown no tables found") + else: + self.assertEqual(len(calls), 5) + self.assertEqual(len(puts), 5) @network def test_get_call_data(self): - calls = self.aapl.get_call_data(expiry=self.expiry) - assert len(calls)>1 + try: + calls = self.aapl.get_call_data(expiry=self.expiry) + except IndexError: + warnings.warn("IndexError thrown no tables found") + else: + assert len(calls)>1 @network def test_get_put_data(self): - puts = self.aapl.get_put_data(expiry=self.expiry) - assert len(puts)>1 + try: + puts = self.aapl.get_put_data(expiry=self.expiry) + except IndexError: + warnings.warn("IndexError thrown no tables found") + else: + assert len(puts)>1 class TestOptionsWarnings(unittest.TestCase): @@ -169,7 +197,7 @@ def setUpClass(cls): except ImportError: raise nose.SkipTest - with assert_produces_warning(FutureWarning): + with assert_produces_warning(): cls.aapl = web.Options('aapl') today = datetime.today() @@ -185,30 +213,42 @@ def tearDownClass(cls): @network def test_get_options_data_warning(self): - with assert_produces_warning(FutureWarning): + with assert_produces_warning(): print('month: {0}, year: {1}'.format(self.month, self.year)) - self.aapl.get_options_data(month=self.month, year=self.year) + try: + self.aapl.get_options_data(month=self.month, year=self.year) + except IndexError: + warnings.warn("IndexError thrown no tables found") @network def test_get_near_stock_price_warning(self): - with assert_produces_warning(FutureWarning): + with assert_produces_warning(): print('month: {0}, year: {1}'.format(self.month, self.year)) - calls_near, puts_near = self.aapl.get_near_stock_price(call=True, - put=True, - month=self.month, - year=self.year) + try: + calls_near, puts_near = self.aapl.get_near_stock_price(call=True, + put=True, + month=self.month, + year=self.year) + except IndexError: + warnings.warn("IndexError thrown no tables found") @network def test_get_call_data_warning(self): - with assert_produces_warning(FutureWarning): + with assert_produces_warning(): print('month: {0}, year: {1}'.format(self.month, self.year)) - self.aapl.get_call_data(month=self.month, year=self.year) + try: + self.aapl.get_call_data(month=self.month, year=self.year) + except IndexError: + warnings.warn("IndexError thrown no tables found") @network def test_get_put_data_warning(self): - with assert_produces_warning(FutureWarning): + with assert_produces_warning(): print('month: {0}, year: {1}'.format(self.month, self.year)) - self.aapl.get_put_data(month=self.month, year=self.year) + try: + self.aapl.get_put_data(month=self.month, year=self.year) + except IndexError: + warnings.warn("IndexError thrown no tables found") if __name__ == '__main__':