diff --git a/plotly/tests/test_optional/test_figure_factory.py b/plotly/tests/test_optional/test_figure_factory.py index 133ae568414..f39c5c8f04e 100644 --- a/plotly/tests/test_optional/test_figure_factory.py +++ b/plotly/tests/test_optional/test_figure_factory.py @@ -8,6 +8,7 @@ from nose.tools import raises import numpy as np +from scipy.spatial import Delaunay import pandas as pd @@ -532,6 +533,182 @@ def test_dendrogram_colorscale(self): self.assert_dict_equal(dendro['data'][2], expected_dendro['data'][2]) +class TestTrisurf(NumpyTestUtilsMixin, TestCase): + + def test_vmin_and_vmax(self): + + # check if vmin is greater than or equal to vmax + u = np.linspace(0, 2, 2) + v = np.linspace(0, 2, 2) + u, v = np.meshgrid(u, v) + u = u.flatten() + v = v.flatten() + + x = u + y = v + z = u*v + + points2D = np.vstack([u, v]).T + tri = Delaunay(points2D) + simplices = tri.simplices + + pattern = ( + "Incorrect relation between vmin and vmax. The vmin value cannot " + "be bigger than or equal to the value of vmax." + ) + + self.assertRaisesRegexp(PlotlyError, pattern, + tls.FigureFactory.create_trisurf, + x, y, z, simplices) + + def test_valid_colormap(self): + + # create data for trisurf plot + u = np.linspace(-np.pi, np.pi, 3) + v = np.linspace(-np.pi, np.pi, 3) + u, v = np.meshgrid(u, v) + u = u.flatten() + v = v.flatten() + + x = u + y = u*np.cos(v) + z = u*np.sin(v) + + points2D = np.vstack([u, v]).T + tri = Delaunay(points2D) + simplices = tri.simplices + + # check that a valid plotly colorscale name is entered + self.assertRaises(PlotlyError, tls.FigureFactory.create_trisurf, + x, y, z, simplices, colormap='foo') + + # check that colormap is a list, if not a string + + pattern1 = ( + "If 'colormap' is a list, then its items must be tripets of the " + "form a,b,c or 'rgbx,y,z' where a,b,c are between 0 and 1 " + "inclusive and x,y,z are between 0 and 255 inclusive." + ) + + self.assertRaisesRegexp(PlotlyError, pattern1, + tls.FigureFactory.create_trisurf, + x, y, z, simplices, colormap=3) + + # check: if colormap is a list of rgb color strings, make sure the + # entries of each color are no greater than 255.0 + + pattern2 = ( + "Whoops! The elements in your rgb colormap tuples " + "cannot exceed 255.0." + ) + + self.assertRaisesRegexp(PlotlyError, pattern2, + tls.FigureFactory.create_trisurf, + x, y, z, simplices, + colormap=['rgb(1, 2, 3)', 'rgb(4, 5, 600)']) + + # check: if colormap is a list of tuple colors, make sure the entries + # of each tuple are no greater than 1.0 + + pattern3 = ( + "Whoops! The elements in your rgb colormap tuples " + "cannot exceed 1.0." + ) + + self.assertRaisesRegexp(PlotlyError, pattern3, + tls.FigureFactory.create_trisurf, + x, y, z, simplices, + colormap=[(0.2, 0.4, 0.6), (0.8, 1.0, 1.2)]) + + def test_trisurf_all_args(self): + + # check if trisurf plot matches with expected output + u = np.linspace(-1, 1, 3) + v = np.linspace(-1, 1, 3) + u, v = np.meshgrid(u, v) + u = u.flatten() + v = v.flatten() + + x = u + y = v + z = u*v + + points2D = np.vstack([u, v]).T + tri = Delaunay(points2D) + simplices = tri.simplices + + test_trisurf_plot = tls.FigureFactory.create_trisurf( + x, y, z, simplices + ) + + exp_trisurf_plot = { + 'data': [ + { + 'facecolor': ['rgb(143.0, 123.0, 97.000000000000014)', + 'rgb(255.0, 127.0, 14.000000000000007)', + 'rgb(143.0, 123.0, 97.000000000000014)', + 'rgb(31.0, 119.0, 180.0)', + 'rgb(143.0, 123.0, 97.000000000000014)', + 'rgb(31.0, 119.0, 180.0)', + 'rgb(143.0, 123.0, 97.000000000000014)', + 'rgb(255.0, 127.0, 14.000000000000007)'], + 'i': [3, 1, 1, 5, 7, 3, 5, 7], + 'j': [1, 3, 5, 1, 3, 7, 7, 5], + 'k': [4, 0, 4, 2, 4, 6, 4, 8], + 'name': '', + 'type': 'mesh3d', + 'x': np.array([-1., 0., 1., -1., 0., 1., -1., 0., 1.]), + 'y': np.array([-1., -1., -1., 0., 0., 0., 1., 1., 1.]), + 'z': np.array([ 1., -0., -1., -0., 0., 0., -1., 0., 1.]) + }, + { + 'line': {'color': 'rgb(50, 50, 50)', 'width': 1.5}, + 'mode': 'lines', + 'type': 'scatter3d', + 'x': [-1.0, 0.0, 0.0, -1.0, None, 0.0, -1.0, -1.0, 0.0, None, + 0.0, 1.0, 0.0, 0.0, None, 1.0, 0.0, 1.0, 1.0, None, 0.0, + -1.0, 0.0, 0.0, None, -1.0, 0.0, -1.0, -1.0, None, 1.0, + 0.0, 0.0, 1.0, None, 0.0, 1.0, 1.0, 0.0, None], + 'y': [0.0, -1.0, 0.0, 0.0, None, -1.0, 0.0, -1.0, -1.0, None, + -1.0, 0.0, 0.0, -1.0, None, 0.0, -1.0, -1.0, 0.0, None, + 1.0, 0.0, 0.0, 1.0, None, 0.0, 1.0, 1.0, 0.0, None, 0.0, + 1.0, 0.0, 0.0, None, 1.0, 0.0, 1.0, 1.0, None], + 'z': [-0.0, -0.0, 0.0, -0.0, None, -0.0, -0.0, 1.0, -0.0, + None, -0.0, 0.0, 0.0, -0.0, None, 0.0, -0.0, -1.0, 0.0, + None, 0.0, -0.0, 0.0, 0.0, None, -0.0, 0.0, -1.0, -0.0, + None, 0.0, 0.0, 0.0, 0.0, None, 0.0, 0.0, 1.0, 0.0, None] + } + ], + 'layout': { + 'height': 800, + 'scene': {'aspectratio': {'x': 1, 'y': 1, 'z': 1}, + 'xaxis': {'backgroundcolor': 'rgb(230, 230, 230)', + 'gridcolor': 'rgb(255, 255, 255)', + 'showbackground': True, + 'zerolinecolor': 'rgb(255, 255, 255)'}, + 'yaxis': {'backgroundcolor': 'rgb(230, 230, 230)', + 'gridcolor': 'rgb(255, 255, 255)', + 'showbackground': True, + 'zerolinecolor': 'rgb(255, 255, 255)'}, + 'zaxis': {'backgroundcolor': 'rgb(230, 230, 230)', + 'gridcolor': 'rgb(255, 255, 255)', + 'showbackground': True, + 'zerolinecolor': 'rgb(255, 255, 255)'}}, + 'title': 'Trisurf Plot', + 'width': 800 + } + } + + self.assert_dict_equal(test_trisurf_plot['layout'], + exp_trisurf_plot['layout']) + + self.assert_dict_equal(test_trisurf_plot['data'][0], + exp_trisurf_plot['data'][0]) + + self.assert_dict_equal(test_trisurf_plot['data'][1], + exp_trisurf_plot['data'][1]) + + class TestScatterPlotMatrix(NumpyTestUtilsMixin, TestCase): def test_dataframe_input(self): @@ -703,7 +880,7 @@ def test_scatter_plot_matrix(self): columns=['Numbers', 'Fruit']) test_scatter_plot_matrix = tls.FigureFactory.create_scatterplotmatrix( - df, diag='scatter', height=1000, width=1000, size=13, + df=df, diag='scatter', height=1000, width=1000, size=13, title='Scatterplot Matrix', use_theme=False ) diff --git a/plotly/tools.py b/plotly/tools.py index 935909f75c7..0a109d08c37 100644 --- a/plotly/tools.py +++ b/plotly/tools.py @@ -28,6 +28,7 @@ 'rgb(227, 119, 194)', 'rgb(127, 127, 127)', 'rgb(188, 189, 34)', 'rgb(23, 190, 207)'] + # Warning format def warning_on_one_line(message, category, filename, lineno, file=None, line=None): @@ -1449,6 +1450,431 @@ class FigureFactory(object): more information and examples of a specific chart type. """ + @staticmethod + def _find_intermediate_color(lowcolor, highcolor, intermed): + """ + Returns the color at a given distance between two colors + + This function takes two color tuples, where each element is between 0 + and 1, along with a value 0 < intermed < 1 and returns a color that is + intermed-percent from lowcolor to highcolor + + """ + diff_0 = float(highcolor[0] - lowcolor[0]) + diff_1 = float(highcolor[1] - lowcolor[1]) + diff_2 = float(highcolor[2] - lowcolor[2]) + + new_tuple = (lowcolor[0] + intermed*diff_0, + lowcolor[1] + intermed*diff_1, + lowcolor[2] + intermed*diff_2) + + return new_tuple + + @staticmethod + def _unconvert_from_RGB_255(colors): + """ + Return a tuple where each element gets divided by 255 + + Takes a list of color tuples where each element is between 0 and 255 + and returns the same list where each tuple element is normalized to be + between 0 and 1 + + """ + un_rgb_colors = [] + for color in colors: + un_rgb_color = (color[0]/(255.0), + color[1]/(255.0), + color[2]/(255.0)) + + un_rgb_colors.append(un_rgb_color) + + return un_rgb_colors + + @staticmethod + def _map_z2color(zval, colormap, vmin, vmax): + """ + Returns the color corresponding zval's place between vmin and vmax + + This function takes a z value (zval) along with a colormap and a + minimum (vmin) and maximum (vmax) range of possible z values for the + given parametrized surface. It returns an rgb color based on the + relative position of zval between vmin and vmax + + """ + if vmin >= vmax: + raise exceptions.PlotlyError("Incorrect relation between vmin " + "and vmax. The vmin value cannot be " + "bigger than or equal to the value " + "of vmax.") + # find distance t of zval from vmin to vmax where the distance + # is normalized to be between 0 and 1 + t = (zval - vmin)/float((vmax - vmin)) + t_color = FigureFactory._find_intermediate_color(colormap[0], + colormap[1], + t) + t_color = (t_color[0]*255.0, t_color[1]*255.0, t_color[2]*255.0) + labelled_color = 'rgb{}'.format(t_color) + + return labelled_color + + @staticmethod + def _tri_indices(simplices): + """ + Returns a triplet of lists containing simplex coordinates + """ + return ([triplet[c] for triplet in simplices] for c in range(3)) + + @staticmethod + def _trisurf(x, y, z, simplices, colormap=None, dist_func=None, + plot_edges=None, x_edge=None, y_edge=None, z_edge=None): + """ + Refer to FigureFactory.create_trisurf() for docstring + """ + # numpy import check + if _numpy_imported is False: + raise ImportError("FigureFactory._trisurf() requires " + "numpy imported.") + import numpy as np + from plotly.graph_objs import graph_objs + points3D = np.vstack((x, y, z)).T + + # vertices of the surface triangles + tri_vertices = list(map(lambda index: points3D[index], simplices)) + + if not dist_func: + # mean values of z-coordinates of triangle vertices + mean_dists = [np.mean(tri[:, 2]) for tri in tri_vertices] + else: + # apply user inputted function to calculate + # custom coloring for triangle vertices + mean_dists = [] + + for triangle in tri_vertices: + dists = [] + for vertex in triangle: + dist = dist_func(vertex[0], vertex[1], vertex[2]) + dists.append(dist) + + mean_dists.append(np.mean(dists)) + + min_mean_dists = np.min(mean_dists) + max_mean_dists = np.max(mean_dists) + facecolor = ([FigureFactory._map_z2color(zz, colormap, min_mean_dists, + max_mean_dists) for zz in mean_dists]) + ii, jj, kk = FigureFactory._tri_indices(simplices) + + triangles = graph_objs.Mesh3d(x=x, y=y, z=z, facecolor=facecolor, + i=ii, j=jj, k=kk, name='') + + if plot_edges is None: # the triangle sides are not plotted + return graph_objs.Data([triangles]) + + # define the lists x_edge, y_edge and z_edge, of x, y, resp z + # coordinates of edge end points for each triangle + # None separates data corresponding to two consecutive triangles + lists_coord = ([[[T[k % 3][c] for k in range(4)]+[None] + for T in tri_vertices] for c in range(3)]) + if x_edge is None: + x_edge = [] + for array in lists_coord[0]: + for item in array: + x_edge.append(item) + + if y_edge is None: + y_edge = [] + for array in lists_coord[1]: + for item in array: + y_edge.append(item) + + if z_edge is None: + z_edge = [] + for array in lists_coord[2]: + for item in array: + z_edge.append(item) + + # define the lines for plotting + lines = graph_objs.Scatter3d( + x=x_edge, y=y_edge, z=z_edge, mode='lines', + line=graph_objs.Line(color='rgb(50, 50, 50)', + width=1.5) + ) + + return graph_objs.Data([triangles, lines]) + + @staticmethod + def create_trisurf(x, y, z, simplices, colormap=None, + dist_func=None, title='Trisurf Plot', + showbackground=True, + backgroundcolor='rgb(230, 230, 230)', + gridcolor='rgb(255, 255, 255)', + zerolinecolor='rgb(255, 255, 255)', + height=800, width=800, + aspectratio=dict(x=1, y=1, z=1)): + """ + Returns figure for a triangulated surface plot + + :param (array) x: data values of x in a 1D array + :param (array) y: data values of y in a 1D array + :param (array) z: data values of z in a 1D array + :param (array) simplices: an array of shape (ntri, 3) where ntri is + the number of triangles in the triangularization. Each row of the + array contains the indicies of the verticies of each triangle. + :param (str|list) colormap: either a plotly scale name, or a list + containing 2 triplets. These triplets must be of the form (a,b,c) + or 'rgb(x,y,z)' where a,b,c belong to the interval [0,1] and x,y,z + belong to [0,255] + :param (function) dist_func: The function that determines how the + coloring of the surface changes. It takes 3 arguments x, y, z and + must return a formula of these variables which can include numpy + functions (eg. np.sqrt). If set to None, color will only depend on + the z axis. + :param (str) title: title of the plot + :param (bool) showbackground: makes background in plot visible + :param (str) backgroundcolor: color of background. Takes a string of + the form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive. + :param (str) gridcolor: color of the gridlines besides the axes. Takes + a string of the form 'rgb(x,y,z)' x,y,z are between 0 and 255 + inclusive. + :param (str) zerolinecolor: color of the axes. Takes a string of the + form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive. + :param (int|float) height: the height of the plot (in pixels) + :param (int|float) width: the width of the plot (in pixels) + :param (dict) aspectratio: a dictionary of the aspect ratio values for + the x, y and z axes. 'x', 'y' and 'z' take (int|float) values. + + Example 1: Sphere + ``` + # Necessary Imports for Trisurf + import numpy as np + from scipy.spatial import Delaunay + + import plotly.plotly as py + from plotly.tools import FigureFactory as FF + from plotly.graph_objs import graph_objs + + # Make data for plot + u = np.linspace(0, 2*np.pi, 20) + v = np.linspace(0, np.pi, 20) + u,v = np.meshgrid(u,v) + u = u.flatten() + v = v.flatten() + + x = np.sin(v)*np.cos(u) + y = np.sin(v)*np.sin(u) + z = np.cos(v) + + points2D = np.vstack([u,v]).T + tri = Delaunay(points2D) + simplices = tri.simplices + + # Create a figure + fig1 = FF.create_trisurf(x=x, y=y, z=z, + colormap="Blues", + simplices=simplices) + # Plot the data + py.iplot(fig1, filename='Trisurf Plot - Sphere') + ``` + + Example 2: Torus + ``` + # Necessary Imports for Trisurf + import numpy as np + from scipy.spatial import Delaunay + + import plotly.plotly as py + from plotly.tools import FigureFactory as FF + from plotly.graph_objs import graph_objs + + # Make data for plot + u = np.linspace(0, 2*np.pi, 20) + v = np.linspace(0, 2*np.pi, 20) + u,v = np.meshgrid(u,v) + u = u.flatten() + v = v.flatten() + + x = (3 + (np.cos(v)))*np.cos(u) + y = (3 + (np.cos(v)))*np.sin(u) + z = np.sin(v) + + points2D = np.vstack([u,v]).T + tri = Delaunay(points2D) + simplices = tri.simplices + + # Create a figure + fig1 = FF.create_trisurf(x=x, y=y, z=z, + colormap="Portland", + simplices=simplices) + # Plot the data + py.iplot(fig1, filename='Trisurf Plot - Torus') + ``` + + Example 3: Mobius Band + ``` + # Necessary Imports for Trisurf + import numpy as np + from scipy.spatial import Delaunay + + import plotly.plotly as py + from plotly.tools import FigureFactory as FF + from plotly.graph_objs import graph_objs + + # Make data for plot + u = np.linspace(0, 2*np.pi, 24) + v = np.linspace(-1, 1, 8) + u,v = np.meshgrid(u,v) + u = u.flatten() + v = v.flatten() + + tp = 1 + 0.5*v*np.cos(u/2.) + x = tp*np.cos(u) + y = tp*np.sin(u) + z = 0.5*v*np.sin(u/2.) + + points2D = np.vstack([u,v]).T + tri = Delaunay(points2D) + simplices = tri.simplices + + # Create a figure + fig1 = FF.create_trisurf(x=x, y=y, z=z, + colormap=[(0.2, 0.4, 0.6),(1, 1, 1)], + simplices=simplices) + # Plot the data + py.iplot(fig1, filename='Trisurf Plot - Mobius Band') + ``` + + Example 4: Using a Custom Colormap Function with Light Cone + ``` + # Necessary Imports for Trisurf + import numpy as np + from scipy.spatial import Delaunay + + import plotly.plotly as py + from plotly.tools import FigureFactory as FF + from plotly.graph_objs import graph_objs + + # Make data for plot + u=np.linspace(-np.pi, np.pi, 30) + v=np.linspace(-np.pi, np.pi, 30) + u,v=np.meshgrid(u,v) + u=u.flatten() + v=v.flatten() + + x = u + y = u*np.cos(v) + z = u*np.sin(v) + + points2D = np.vstack([u,v]).T + tri = Delaunay(points2D) + simplices = tri.simplices + + # Define distance function + def dist_origin(x, y, z): + return np.sqrt((1.0 * x)**2 + (1.0 * y)**2 + (1.0 * z)**2) + + # Create a figure + fig1 = FF.create_trisurf(x=x, y=y, z=z, + colormap="Blues", + simplices=simplices, + dist_func=dist_origin) + # Plot the data + py.iplot(fig1, filename='Trisurf Plot - Custom Coloring') + ``` + """ + from plotly.graph_objs import graph_objs + plotly_scales = {'Greys': ['rgb(0,0,0)', 'rgb(255,255,255)'], + 'YlGnBu': ['rgb(8,29,88)', 'rgb(255,255,217)'], + 'Greens': ['rgb(0,68,27)', 'rgb(247,252,245)'], + 'YlOrRd': ['rgb(128,0,38)', 'rgb(255,255,204)'], + 'Bluered': ['rgb(0,0,255)', 'rgb(255,0,0)'], + 'RdBu': ['rgb(5,10,172)', 'rgb(178,10,28)'], + 'Reds': ['rgb(220,220,220)', 'rgb(178,10,28)'], + 'Blues': ['rgb(5,10,172)', 'rgb(220,220,220)'], + 'Picnic': ['rgb(0,0,255)', 'rgb(255,0,0)'], + 'Rainbow': ['rgb(150,0,90)', 'rgb(255,0,0)'], + 'Portland': ['rgb(12,51,131)', 'rgb(217,30,30)'], + 'Jet': ['rgb(0,0,131)', 'rgb(128,0,0)'], + 'Hot': ['rgb(0,0,0)', 'rgb(255,255,255)'], + 'Blackbody': ['rgb(0,0,0)', 'rgb(160,200,255)'], + 'Earth': ['rgb(0,0,130)', 'rgb(255,255,255)'], + 'Electric': ['rgb(0,0,0)', 'rgb(255,250,220)'], + 'Viridis': ['rgb(68,1,84)', 'rgb(253,231,37)']} + + # Validate colormap + if colormap is None: + colormap = [DEFAULT_PLOTLY_COLORS[0], + DEFAULT_PLOTLY_COLORS[1]] + colormap = FigureFactory._unlabel_rgb(colormap) + colormap = FigureFactory._unconvert_from_RGB_255(colormap) + + if isinstance(colormap, str): + if colormap not in plotly_scales: + scale_keys = list(plotly_scales.keys()) + raise exceptions.PlotlyError("You must pick a valid " + "plotly colorscale " + "name from " + "{}".format(scale_keys)) + + colormap = [plotly_scales[colormap][0], + plotly_scales[colormap][1]] + colormap = FigureFactory._unlabel_rgb(colormap) + colormap = FigureFactory._unconvert_from_RGB_255(colormap) + + else: + if not isinstance(colormap, list): + raise exceptions.PlotlyError("If 'colormap' is a list, then " + "its items must be tripets of " + "the form a,b,c or 'rgbx,y,z' " + "where a,b,c are between 0 and " + "1 inclusive and x,y,z are " + "between 0 and 255 inclusive.") + if 'rgb' in colormap[0]: + colormap = FigureFactory._unlabel_rgb(colormap) + for color in colormap: + for index in range(3): + if color[index] > 255.0: + raise exceptions.PlotlyError("Whoops! The " + "elements in your " + "rgb colormap " + "tuples cannot " + "exceed 255.0.") + colormap = FigureFactory._unconvert_from_RGB_255(colormap) + + if isinstance(colormap[0], tuple): + for color in colormap: + for index in range(3): + if color[index] > 1.0: + raise exceptions.PlotlyError("Whoops! The " + "elements in your " + "rgb colormap " + "tuples cannot " + "exceed 1.0.") + + data1 = FigureFactory._trisurf(x, y, z, simplices, + dist_func=dist_func, + colormap=colormap, + plot_edges=True) + axis = dict( + showbackground=showbackground, + backgroundcolor=backgroundcolor, + gridcolor=gridcolor, + zerolinecolor=zerolinecolor, + ) + layout = graph_objs.Layout( + title=title, + width=width, + height=height, + scene=graph_objs.Scene( + xaxis=graph_objs.XAxis(axis), + yaxis=graph_objs.YAxis(axis), + zaxis=graph_objs.ZAxis(axis), + aspectratio=dict( + x=aspectratio['x'], + y=aspectratio['y'], + z=aspectratio['z']), + ) + ) + return graph_objs.Figure(data=data1, layout=layout) + @staticmethod def _scatterplot(dataframe, headers, diag, size,