Skip to content

Commit f72beed

Browse files
committed
Adding Scatterplot Matrix to FigureFactory
1 parent a8a68fb commit f72beed

File tree

1 file changed

+370
-0
lines changed

1 file changed

+370
-0
lines changed

plotly/tools.py

Lines changed: 370 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1442,6 +1442,376 @@ class FigureFactory(object):
14421442
more information and examples of a specific chart type.
14431443
"""
14441444

1445+
@staticmethod
def create_scatterplotmatrix(df, useindex=False, index=" ",
                             diagonal="Scatter", symbol=0, size=6,
                             height=" ", width=" ", jitter=0,
                             title="Scatterplot Matrix"):
    """
    Build a scatterplot-matrix figure from the columns of a dataframe.

    Every column of `df` (except the `index` column when `useindex` is
    set) is plotted against every other column on a dim x dim grid of
    subplots created with `tools.make_subplots`.

    :param df: pandas DataFrame with at least two columns. The entries
        of a column must be either all numbers or all strings.
    :param useindex: when truthy, the column named by `index` is not
        plotted; instead its values split the rows into groups that are
        drawn as separate, individually coloured traces.
    :param index: name of the grouping column. Only read when
        `useindex` is truthy.
    :param diagonal: "Scatter" or "Histogram"; the trace type drawn in
        the cells where a column meets itself.
    :param symbol: forwarded to the scatter traces as `marker.symbol`.
    :param size: forwarded to the scatter traces as `marker.size`.
    :param height: figure height. The default " " is a sentinel meaning
        "pick 400 + 200 * (dim - 1)".
    :param width: figure width, with the same sentinel as `height`.
    :param jitter: number in [0, 1]. Each numeric coordinate of a
        scatter trace is shifted by `jitter * uniform(-1, 1)`.
        NOTE: as in the original implementation, jitter is only
        validated and applied when `useindex` is falsy.
    :param title: figure title.
    :return: the figure returned by `tools.make_subplots`, populated
        with traces, axis titles and layout.
    :raises PlotlyError: if `df` is not a DataFrame, has fewer than two
        columns, `diagonal` is not a valid choice, `index` is not a
        column of `df` (when `useindex` is set), a column mixes numbers
        and strings, or `jitter` lies outside [0, 1] (non-index mode).
    """
    from plotly import exceptions
    from plotly import tools
    import plotly.graph_objs as go
    import pandas as pd
    import random as r

    colors = ["rgb(31, 119, 180)", "rgb(255, 127, 14)",
              "rgb(44, 160, 44)", "rgb(214, 39, 40)",
              "rgb(148, 103, 189)", "rgb(140, 86, 75)",
              "rgb(227, 119, 194)", "rgb(127, 127, 127)",
              "rgb(188, 189, 34)", "rgb(23, 190, 207)"]
    diagonal_choices = ["Scatter", "Histogram"]

    def is_number(value):
        # bool is a subclass of int; keep it out so that a column mixing
        # numbers and booleans is still reported, as before.
        return (isinstance(value, (int, float)) and
                not isinstance(value, bool))

    def validate_homogeneous(values):
        # A column whose first entry is a number (resp. a string) must
        # contain only numbers (resp. strings). Columns starting with
        # any other type are not checked. Empty columns are accepted.
        if not values:
            return
        first = values[0]
        if is_number(first):
            consistent = all(is_number(entry) for entry in values)
        elif isinstance(first, str):
            consistent = all(isinstance(entry, str) for entry in values)
        else:
            consistent = True
        if not consistent:
            raise exceptions.PlotlyError("Error in dataframe. Make sure "
                                         "that all entries of each "
                                         "column are either numbers or "
                                         "strings.")

    def jittered(values):
        # Return a jittered COPY of a numeric column. The source column
        # is never modified, so noise cannot accumulate when the same
        # column is reused by several subplots. With jitter == 0 (the
        # default) the values are passed through untouched.
        if (not jitter or not values or
                not isinstance(values[0], (int, float))):
            return values
        return [value + jitter * r.uniform(-1, 1) for value in values]

    # ---- Input validation -------------------------------------------
    if not isinstance(df, pd.DataFrame):
        raise exceptions.PlotlyError("Dataframe not inputed. Please "
                                     "use a pandas dataframe to pro"
                                     "duce a scatterplot matrix.")

    # A scatterplot matrix needs at least two columns.
    if len(df.columns) <= 1:
        raise exceptions.PlotlyError("Dataframe has only one column. To "
                                     "use the scatterplot matrix, use at "
                                     "least 2 columns.")

    if diagonal not in diagonal_choices:
        raise exceptions.PlotlyError("Make sure diagonal is set to "
                                     "either Scatter or Histogram.")

    index_vals = []
    if useindex:
        if index not in df:
            raise exceptions.PlotlyError("Make sure you set the index "
                                         "input variable to one of the "
                                         "column names of your matrix.")
        index_vals = df[index].values.tolist()
        # The grouping column itself is not plotted.
        col_names = [name for name in df if name != index]
    else:
        col_names = [name for name in df]
        if (jitter < 0.0) or (jitter > 1.0):
            raise exceptions.PlotlyError("Jitter must lie "
                                         "between 0 and 1.0 "
                                         "inclusive.")

    # One list of plain Python values per plotted column.
    matrix = [df[name].values.tolist() for name in col_names]

    for vector in matrix:
        validate_homogeneous(vector)
    validate_homogeneous(index_vals)

    # ---- Build the figure -------------------------------------------
    dim = len(matrix)
    histogram_diagonal = (diagonal == "Histogram")
    fig = tools.make_subplots(rows=dim, cols=dim)

    if not useindex:
        # Rows and columns of the subplot grid are 1-based.
        for row, column_y in enumerate(matrix, 1):
            for col, column_x in enumerate(matrix, 1):
                # Compare grid positions, not values: two different
                # columns holding equal data are not on the diagonal.
                on_diagonal = (row == col)
                if on_diagonal and histogram_diagonal:
                    trace = go.Histogram(
                        x=column_x,
                        showlegend=False
                    )
                else:
                    xs = jittered(column_x)
                    # A column plotted against itself keeps x == y.
                    ys = xs if on_diagonal else jittered(column_y)
                    trace = go.Scatter(
                        x=xs,
                        y=ys,
                        mode="markers",
                        marker=dict(
                            symbol=symbol,
                            size=size),
                        showlegend=False,
                    )
                fig.append_trace(trace, row, col)
    else:
        # Group row positions by index value, remembering the order in
        # which each value first appears so that colours are assigned
        # deterministically. Computed once instead of once per cell.
        positions = {}
        group_names = []
        for position, value in enumerate(index_vals):
            if value not in positions:
                positions[value] = []
                group_names.append(value)
            positions[value].append(position)

        # Only one cell contributes legend entries, otherwise every
        # group would be listed dim * dim times. The second cell is the
        # first off-diagonal one; a 1x1 grid only has cell 0.
        legend_cell = 1 if dim > 1 else 0

        cell = 0
        for row, column_y in enumerate(matrix, 1):
            for col, column_x in enumerate(matrix, 1):
                on_diagonal = (row == col)
                show_legend = (cell == legend_cell)

                for group_number, name in enumerate(group_names):
                    # Cycle through the palette when there are more
                    # groups than colours.
                    color = colors[group_number % len(colors)]
                    group_x = [column_x[j] for j in positions[name]]

                    if on_diagonal and histogram_diagonal:
                        trace = go.Histogram(
                            x=group_x,
                            name=name,
                            marker=dict(
                                color=color),
                            showlegend=show_legend
                        )
                    else:
                        group_y = [column_y[j] for j in positions[name]]
                        trace = go.Scatter(
                            x=group_x,
                            y=group_y,
                            mode="markers",
                            name=name,
                            marker=dict(
                                symbol=symbol,
                                size=size,
                                color=color),
                            showlegend=show_legend
                        )
                    fig.append_trace(trace, row, col)
                cell += 1

    # ---- Axis titles and layout -------------------------------------
    # Axis numbering follows the original code: x titles go on axes
    # dim*dim - dim + 1 .. dim*dim and y titles on axes 1, 1 + dim, ...
    # (presumably the bottom row and left column of the grid -- verify
    # against make_subplots' numbering).
    for j, col_name in enumerate(col_names):
        xaxis_place = "xaxis" + str(dim * dim - dim + 1 + j)
        fig['layout'][xaxis_place].update(title=col_name)

        yaxis_place = "yaxis" + str(1 + dim * j)
        fig['layout'][yaxis_place].update(title=col_name)

    # Set height and width if not already selected by user
    if height == " ":
        height = 400 + 200 * (dim - 1)

    if width == " ":
        width = 400 + 200 * (dim - 1)

    layout_options = dict(
        height=height, width=width,
        title=title,
        showlegend=True
    )
    # Per-group histograms sharing a diagonal cell are stacked.
    if useindex and histogram_diagonal:
        layout_options['barmode'] = "stack"

    fig['layout'].update(**layout_options)
    return fig
1813+
1814+
14451815
@staticmethod
14461816
def _validate_equal_length(*args):
14471817
"""

0 commit comments

Comments
 (0)