Skip to content

Commit ede9bd4

Browse files
ols log options checkpoint
1 parent b6dd8f9 commit ede9bd4

File tree

1 file changed

+12
-3
lines changed
  • packages/python/plotly/plotly/express

1 file changed

+12
-3
lines changed

packages/python/plotly/plotly/express/_core.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,10 +256,19 @@ def ols(options, x, y, x_label, y_label, non_missing):
256256
import statsmodels.api as sm
257257

258258
add_constant = options.get("add_constant", True)
259-
fit_results = sm.OLS(
260-
y, sm.add_constant(x) if add_constant else x, missing="drop"
261-
).fit()
259+
log_x = options.get("log_x", False)
260+
log_y = options.get("log_y", False)
261+
262+
if log_y:
263+
y = np.log(y)
264+
if log_x:
265+
x = np.log(x)
266+
if add_constant:
267+
x = sm.add_constant(x)
268+
fit_results = sm.OLS(y, x, missing="drop").fit()
262269
y_out = fit_results.predict()
270+
if log_y:
271+
y_out = np.exp(y_out)
263272
hover_header = "<b>OLS trendline</b><br>"
264273
if len(fit_results.params) == 2:
265274
hover_header += "%s = %g * %s + %g<br>" % (

0 commit comments

Comments
 (0)