-
-
Notifications
You must be signed in to change notification settings - Fork 2.7k
Adding Scatterplot Matrix to FigureFactory #417
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
f72beed
04c8e80
f7ba1c3
70c916d
68c4143
7d43513
89d28f3
c9d8bf7
6698fcd
3ba79c2
08745cd
14a4bfd
bbd02a5
a6a91c2
c3e7755
09d3c5d
bb48595
1957cf2
fbfe495
a87a07b
84fc4c4
f98da6a
f4764b4
e25d392
7a2dffa
b3bf96d
398b8d5
7afe967
718dd71
78608a7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1442,6 +1442,376 @@ class FigureFactory(object): | |
more information and examples of a specific chart type. | ||
""" | ||
|
||
@staticmethod | ||
def create_scatterplotmatrix(df, useindex = False, index = " ", | ||
diagonal = "Scatter", symbol = 0, size = 6, | ||
height = " ", width = " ", jitter = 0, | ||
title = "Scatterplot Matrix"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We typically keep pretty lengthy docs for things like this This way, when users run |
||
|
||
from plotly import tools | ||
import plotly.plotly as py | ||
import plotly.graph_objs as go | ||
import pandas as pd | ||
import random as r | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general, it's a good idea to import things at the top of a file, when possible, not dynamically like this. AFAIK, the only thing that needs to be imported down here is https://github.com/plotly/plotly.py/blob/master/plotly/tools.py#L1743-L1744 Additionally, we protect Can you first add a conditional check for Finally, I think we typically just Some of this stuff is a little confusing, let me know if you have questions. |
||
|
||
colors = ["rgb(31, 119, 180)", "rgb(255, 127, 14)", | ||
"rgb(44, 160, 44)", "rgb(214, 39, 40)", | ||
"rgb(148, 103, 189)", "rgb(140, 86, 75)", | ||
"rgb(227, 119, 194)", "rgb(127, 127, 127)", | ||
"rgb(188, 189, 34)", "rgb(23, 190, 207)"] | ||
|
||
matrix = [] | ||
col_names = [] | ||
index_vals = [] | ||
diagonal_choices = ["Scatter", "Histogram"] | ||
|
||
# Check if pandas dataframe | ||
if type(df) != pd.core.frame.DataFrame: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You should try to prefer |
||
raise exceptions.PlotlyError("Dataframe not inputed. Please " | ||
"use a pandas dataframe to pro" | ||
"duce a scatterplot matrix.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great messages! |
||
|
||
# Check if dataframe is 1 row or less. | ||
if len(df.columns) <= 1: | ||
raise exceptions.PlotlyError("Dataframe has only one column. To " | ||
"use the scatterplot matrix, use at " | ||
"least 2 columns.") | ||
|
||
# Check that diagonal parameter is selected properly | ||
if diagonal not in diagonal_choices: | ||
raise exceptions.PlotlyError("Make sure diagonal is set to " | ||
"either Scatter or Histogram.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you condense all the input checks into a single private Here's an example Note that you might be able to reuse other validation methods in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, to ensure that these messages don't get out-of-sync with the actual variables, you should do something like:
|
||
|
||
|
||
if useindex == True: | ||
if index not in df: | ||
raise exceptions.PlotlyError("Make sure you set the index " | ||
"input variable to one of the " | ||
"column names of your matrix.") | ||
|
||
else: | ||
index_vals = df[index].values.tolist() | ||
|
||
# Make list of column names besides index | ||
for name in df: | ||
if name != index: | ||
col_names.append(name) | ||
|
||
# Populate matrix = [] | ||
for name in col_names: | ||
matrix.append(df[name].values.tolist()) | ||
|
||
# Check if Matrix Values are either all strings or numbers | ||
for vector in matrix: | ||
if ((type(vector[0]) == float) or (type(vector[0]) == int)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
for entry in vector: | ||
if not ((type(entry) == float) or (type(entry) == int)): | ||
raise exceptions.PlotlyError("Error in data" | ||
"frame. Make sure " | ||
"that all entries " | ||
"of each column are " | ||
"either numbers or " | ||
"strings.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just a style thing, but this would be easier to read if written:
That goes for a bunch of these messages. |
||
|
||
if (type(vector[0]) == str): | ||
for entry in vector: | ||
if (type(entry) != str): | ||
raise exceptions.PlotlyError("Error in data" | ||
"frame. Make sure " | ||
"that all entries " | ||
"of each column are " | ||
"either numbers or " | ||
"strings.") | ||
|
||
# Check if index_vals are either all strings or numbers | ||
if ((type(index_vals[0]) == float) or (type(index_vals[0]) == int)): | ||
for entry in index_vals: | ||
if not ((type(entry) == float) or (type(entry) == int)): | ||
raise exceptions.PlotlyError("Error in data" | ||
"frame. Make sure " | ||
"that all entries " | ||
"of each column are " | ||
"either numbers or " | ||
"strings.") | ||
|
||
if (type(index_vals[0]) == str): | ||
for entry in index_vals: | ||
if (type(entry) != str): | ||
raise exceptions.PlotlyError("Error in data" | ||
"frame. Make sure " | ||
"that all entries " | ||
"of each column are " | ||
"either numbers or " | ||
"strings.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lots of duplication here, you could write two functions |
||
|
||
if useindex == False: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Always use |
||
# Make list of column names | ||
for name in df: | ||
col_names.append(name) | ||
|
||
# Populate matrix = [] with dataframe columns | ||
for name in col_names: | ||
matrix.append(df[name].values.tolist()) | ||
|
||
|
||
# Check if values in each column are either | ||
# all strings or all numbers | ||
|
||
# Matrix Check | ||
for vector in matrix: | ||
if ((type(vector[0]) == float) or (type(vector[0]) == int)): | ||
for entry in vector: | ||
if not ((type(entry) == float) or (type(entry) == int)): | ||
raise exceptions.PlotlyError("Error in data" | ||
"frame. Make sure " | ||
"that all entries " | ||
"of each column are " | ||
"either numbers or " | ||
"strings.") | ||
|
||
if (type(vector[0]) == str): | ||
for entry in vector: | ||
if (type(entry) != str): | ||
raise exceptions.PlotlyError("Error in dataframe. " | ||
"Make sure that all " | ||
"entries of each col" | ||
"umn are either " | ||
"numbers or strings.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, DRY |
||
|
||
# Main Code | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of using a |
||
dim = len(matrix) | ||
trace_list = [] | ||
|
||
if index_vals == []: | ||
fig = tools.make_subplots(rows=dim, cols=dim) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice reuse! |
||
|
||
# Insert traces into trace_list | ||
for listy in matrix: | ||
for listx in matrix: | ||
if (listx == listy) and (diagonal == "Histogram"): | ||
trace = go.Histogram( | ||
x = listx, | ||
showlegend = False | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 😄 lookin' like JS code with all those spaces. I think Python code typically suppresses such extraneous whitespace |
||
else: | ||
# Add Jitter | ||
if (jitter < 0.0) or (jitter > 1.0): | ||
raise exceptions.PlotlyError("Jitter must lie " | ||
"between 0 and 1.0 " | ||
"inclusive.") | ||
|
||
if type(listx[0]) != str: | ||
for j in range(len(listx)): | ||
listx[j] = listx[j] + jitter*r.uniform(-1,1) | ||
|
||
if type(listy[0]) != str: | ||
for j in range(len(listy)): | ||
listy[j] = listy[j] + jitter*r.uniform(-1,1) | ||
|
||
trace = go.Scatter( | ||
x = listx, | ||
y = listy, | ||
mode = "markers", | ||
marker = dict( | ||
symbol = symbol, | ||
size = size), | ||
showlegend = False, | ||
) | ||
trace_list.append(trace) | ||
|
||
# Create list of index values for the axes | ||
# eg. if dim = 3, then the indicies are [1, 2, 3] | ||
indicies = range(dim) | ||
indicies.remove(0) | ||
indicies.append(dim) | ||
j = 0 | ||
|
||
for y_index in indicies: | ||
for x_index in indicies: | ||
fig.append_trace(trace_list[j], y_index, x_index) | ||
j += 1 | ||
|
||
# Check if length of col_names and array match | ||
if len(col_names) != dim: | ||
raise exceptions.PlotlyError("The length of your variable_" | ||
"names list doesn't match the " | ||
"number of lists in your " | ||
"matrix. This means that both " | ||
"col_names and matrix " | ||
"must have the same " | ||
"dimension.") | ||
|
||
# Insert col_names into the figure | ||
for j in range(dim): | ||
xaxis_place = "xaxis" + str(dim*dim - dim + 1 + j) | ||
fig['layout'][xaxis_place].update(title = col_names[j]) | ||
|
||
for j in range(dim): | ||
yaxis_place = "yaxis" + str(1 + dim*j) | ||
fig['layout'][yaxis_place].update(title = col_names[j]) | ||
|
||
# Set height and width if not already selected by user | ||
if (height == " "): | ||
height = 400 + 200*(dim - 1) | ||
|
||
if (width == " "): | ||
width = 400 + 200*(dim - 1) | ||
|
||
fig['layout'].update( | ||
height = height, width = width, | ||
title = title, | ||
showlegend = True | ||
) | ||
return fig | ||
|
||
if index_vals != []: | ||
fig = tools.make_subplots(rows=dim, cols=dim) | ||
|
||
# Checks index_vals for errors | ||
firstvector = matrix[0] | ||
if len(index_vals) != len(firstvector): | ||
raise exceptions.PlotlyError("The length of your index_vals " | ||
"list doesn't match the number of " | ||
"lists in your matrix. Please " | ||
"make sure both the rows of your " | ||
"matrix have the same length as " | ||
"the index_vals list.") | ||
|
||
# Define a paramter that will determine whether | ||
# or not a trace will show or hide its legend | ||
# info when drawn | ||
legend_param = 0 | ||
|
||
# Work over all permutations of list pairs | ||
for listy in matrix: | ||
for listx in matrix: | ||
|
||
# create a dictionary for index_vals | ||
unique_leg_names = {} | ||
for name in index_vals: | ||
if name not in unique_leg_names: | ||
unique_leg_names[name] = [] | ||
|
||
color_index = 0 | ||
|
||
# Fill all the rest of the names into the dictionary | ||
for name in unique_leg_names: | ||
new_listx = [] | ||
new_listy = [] | ||
|
||
for j in range(len(index_vals)): | ||
if index_vals[j] == name: | ||
new_listx.append(listx[j]) | ||
new_listy.append(listy[j]) | ||
|
||
# Generate trace with VISIBLE icon | ||
if legend_param == 1: | ||
if (listx == listy) and (diagonal == "Histogram"): | ||
trace = go.Histogram( | ||
x = new_listx, | ||
marker = dict( | ||
color = colors[color_index]), | ||
showlegend = True | ||
) | ||
else: | ||
trace = go.Scatter( | ||
x = new_listx, | ||
y = new_listy, | ||
mode = "markers", | ||
name = name, | ||
marker = dict( | ||
symbol = symbol, | ||
size = size, | ||
color = colors[color_index]), | ||
showlegend = True | ||
) | ||
|
||
# Generate trace with INVISIBLE icon | ||
if legend_param != 1: | ||
if (listx == listy) and (diagonal == "Histogram"): | ||
trace = go.Histogram( | ||
x = new_listx, | ||
marker = dict( | ||
color = colors[color_index]), | ||
showlegend = False | ||
) | ||
else: | ||
trace = go.Scatter( | ||
x = new_listx, | ||
y = new_listy, | ||
mode = "markers", | ||
name = name, | ||
marker = dict( | ||
symbol = symbol, | ||
size = size, | ||
color = colors[color_index]), | ||
showlegend = False | ||
) | ||
|
||
# Push the trace into dictionary | ||
unique_leg_names[name] = trace | ||
if color_index >= (len(colors) - 1): | ||
color_index = -1 | ||
color_index += 1 | ||
|
||
trace_list.append(unique_leg_names) | ||
legend_param += 1 | ||
|
||
# Create list of index values for the axes | ||
# eg. if dim = 3, then the indicies are [1, 2, 3] | ||
indicies = range(dim) | ||
indicies.remove(0) | ||
indicies.append(dim) | ||
j = 0 | ||
|
||
for y_index in indicies: | ||
for x_index in indicies: | ||
for name in trace_list[j]: | ||
fig.append_trace(trace_list[j][name], y_index, x_index) | ||
j += 1 | ||
|
||
# Check if length of col_names is equal to the | ||
# number of lists in matrix | ||
if len(col_names) != dim: | ||
raise exceptions.PlotlyError("Your list of variable_" | ||
"names must match the " | ||
"number of lists in your " | ||
"array. That is to say that " | ||
"both lists must have the " | ||
"same dimension.") | ||
|
||
# Insert col_names into the figure | ||
for j in range(dim): | ||
xaxis_place = "xaxis" + str(dim*dim - dim + 1 + j) | ||
fig['layout'][xaxis_place].update(title = col_names[j]) | ||
|
||
for j in range(dim): | ||
yaxis_place = "yaxis" + str(1 + dim*j) | ||
fig['layout'][yaxis_place].update(title = col_names[j]) | ||
|
||
# Set height and width if not already selected by user | ||
if (height == " "): | ||
height = 400 + 200*(dim - 1) | ||
|
||
if (width == " "): | ||
width = 400 + 200*(dim - 1) | ||
|
||
if diagonal == "Histogram": | ||
fig['layout'].update( | ||
height = height, width = width, | ||
title = title, | ||
showlegend = True, | ||
barmode = "stack") | ||
return fig | ||
|
||
if diagonal == "Scatter": | ||
fig['layout'].update( | ||
height = height, width = width, | ||
title = title, | ||
showlegend = True) | ||
return fig | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If for some reason |
||
|
||
|
||
@staticmethod | ||
def _validate_equal_length(*args): | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you use
None
, not" "
as default values for these things?" "
is a truthy value, which is confusing, IMO.