Skip to content

feat: add Linear_Regression.global_explain() #1446

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 45 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
4d3e4a3
feat: add Linear_Regression.global_explain()
rey-esp Mar 3, 2025
87db2b7
remove class_level_explain param
rey-esp Mar 4, 2025
0b73343
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 5, 2025
1d0c69b
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 5, 2025
6d563d5
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 11, 2025
024a989
Merge branch 'b338872698-global-explain' of github.com:googleapis/pyt…
rey-esp Mar 11, 2025
813fbd7
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 11, 2025
e99fdd7
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 11, 2025
82a234a
working global_explain()
rey-esp Mar 11, 2025
6dc4fac
Merge branch 'b338872698-global-explain' of github.com:googleapis/pyt…
rey-esp Mar 11, 2025
d583a37
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 11, 2025
ed73f88
begin adding tests
rey-esp Mar 11, 2025
5b7a4b7
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 12, 2025
47b9862
update snippet
rey-esp Mar 12, 2025
606a7b8
Merge branch 'b338872698-global-explain' of github.com:googleapis/pyt…
rey-esp Mar 12, 2025
5fe306f
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 12, 2025
b0e8a5d
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 12, 2025
0664d6a
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 12, 2025
eb33e09
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 12, 2025
31d741d
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 12, 2025
7046dc3
complete snippet
rey-esp Mar 12, 2025
b0b9552
failing, near complete linear model test
rey-esp Mar 12, 2025
3b0526e
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 13, 2025
1ad5208
passing system test
rey-esp Mar 14, 2025
7e24b4c
Merge branch 'b338872698-global-explain' of github.com:googleapis/pyt…
rey-esp Mar 14, 2025
c754816
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 14, 2025
d2d8b0c
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 17, 2025
a600539
Update core.py - set index to have sorted by feature
rey-esp Mar 17, 2025
7fc0cc6
Update test_linear_model.py - remove set/set index
rey-esp Mar 17, 2025
57c3d4a
Update linear_model.py - fix doc section
rey-esp Mar 17, 2025
c2c0837
Update conftest.py - rename penguins w global explain
rey-esp Mar 17, 2025
b2f8c9f
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 17, 2025
3a0c6b9
Update linear_model.py - complete doc
rey-esp Mar 17, 2025
5dac41d
lint
rey-esp Mar 17, 2025
e5f4aad
passing test and fixed expected results
rey-esp Mar 18, 2025
cd321e6
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 18, 2025
26c6a74
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 18, 2025
f47f5b7
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 18, 2025
7bcade0
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 19, 2025
1379a56
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 24, 2025
0bb9186
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 24, 2025
9a2b8e4
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 25, 2025
c8fec3a
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 25, 2025
562d0b8
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 25, 2025
083af6c
Merge branch 'main' into b338872698-global-explain
rey-esp Mar 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions bigframes/ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,16 @@ def explain_predict(
),
)

def global_explain(self, options: Mapping[str, bool]) -> bpd.DataFrame:
    """Query the model's global explanation and return it as a DataFrame.

    Args:
        options:
            Options forwarded as ``struct_options`` to the SQL generator's
            ``ml_global_explain``.

    Returns:
        bpd.DataFrame:
            The query result, sorted by ``attribution`` in descending
            order and indexed by ``feature``.
    """
    query = self._model_manipulation_sql_generator.ml_global_explain(
        struct_options=options
    )
    explanation = self._session.read_gbq(query)
    # Largest attributions first, then key the rows by feature name.
    ranked = explanation.sort_values(by="attribution", ascending=False)
    return ranked.set_index("feature")

def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
return self._apply_ml_tvf(
input_data,
Expand Down
20 changes: 20 additions & 0 deletions bigframes/ml/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,26 @@ def predict_explain(
X, options={"top_k_features": top_k_features}
)

def global_explain(
    self,
) -> bpd.DataFrame:
    """
    Provide explanations for an entire linear regression model.

    .. note::
        Output matches that of the BigQuery ML.GLOBAL_EXPLAIN function.
        See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-global-explain

    Returns:
        bigframes.pandas.DataFrame:
            DataFrame containing feature importance values and corresponding attributions, designed to provide a global explanation of feature influence.

    Raises:
        RuntimeError:
            If no underlying BQML model is set, i.e. the model has not been fitted.
    """

    if not self._bqml_model:
        # Name the operation that was actually attempted so the error is
        # not mistaken for a failed predict() call.
        raise RuntimeError("A model must be fitted before global_explain")

    # No additional options are passed to the underlying explain call.
    return self._bqml_model.global_explain({})

def score(
self,
X: utils.ArrayType,
Expand Down
6 changes: 6 additions & 0 deletions bigframes/ml/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,12 @@ def ml_explain_predict(
return f"""SELECT * FROM ML.EXPLAIN_PREDICT(MODEL {self._model_ref_sql()},
({source_sql}), {struct_options_sql})"""

def ml_global_explain(self, struct_options) -> str:
"""Encode ML.GLOBAL_EXPLAIN for BQML.

Args:
struct_options: Mapping of option names to values. It is unpacked as
keyword arguments into ``self.struct_options`` to build the options
argument of the query.

Returns:
The ``SELECT * FROM ML.GLOBAL_EXPLAIN(...)`` query text, built from
``self._model_ref_sql()`` and the encoded options.
"""
struct_options_sql = self.struct_options(**struct_options)
return f"""SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL {self._model_ref_sql()},
{struct_options_sql})"""

def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
"""Encode ML.FORECAST for BQML"""
struct_options_sql = self.struct_options(**struct_options)
Expand Down
25 changes: 25 additions & 0 deletions samples/snippets/linear_regression_tutorial_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,31 @@ def test_linear_regression(random_model_id: str) -> None:
# 3 5349.603734 [{'feature': 'island', 'attribution': 7348.877... -5320.222128 5349.603734 0.0 Gentoo penguin (Pygoscelis papua) Biscoe 46.4 15.6 221.0 5000.0 MALE
# 4 4637.165037 [{'feature': 'island', 'attribution': 7348.877... -5320.222128 4637.165037 0.0 Gentoo penguin (Pygoscelis papua) Biscoe 46.1 13.2 211.0 4500.0 FEMALE
# [END bigquery_dataframes_bqml_linear_predict_explain]
# [START bigquery_dataframes_bqml_linear_global_explain]
# To use the `global_explain()` function, the model must be recreated with `enable_global_explain` set to `True`.
model = LinearRegression(enable_global_explain=True)

# The model must then be fitted before it can be saved to BigQuery and then explained.
training_data = bq_df.dropna(subset=["body_mass_g"])
X = training_data.drop(columns=["body_mass_g"])
y = training_data[["body_mass_g"]]
model.fit(X, y)
model.to_gbq("bqml_tutorial.penguins_model", replace=True)

# Explain the model
explain_model = model.global_explain()

# Expected results:
# attribution
# feature
# island 5737.315921
# species 4073.280549
# sex 622.070896
# flipper_length_mm 193.612051
# culmen_depth_mm 117.084944
# culmen_length_mm 94.366793
# [END bigquery_dataframes_bqml_linear_global_explain]
assert explain_model is not None
assert feature_columns is not None
assert label_columns is not None
assert model is not None
Expand Down
9 changes: 9 additions & 0 deletions tests/system/small/ml/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ def ephemera_penguins_linear_model(
return bf_model


@pytest.fixture(scope="function")
def penguins_linear_model_w_global_explain(
    penguins_bqml_linear_model: core.BqmlModel,
) -> linear_model.LinearRegression:
    """Return a LinearRegression created with ``enable_global_explain=True``.

    The estimator's ``_bqml_model`` is set to the ``penguins_bqml_linear_model``
    fixture value before it is returned.
    """
    estimator = linear_model.LinearRegression(enable_global_explain=True)
    # Attach the existing BQML model directly to the new estimator.
    estimator._bqml_model = penguins_bqml_linear_model
    return estimator


@pytest.fixture(scope="session")
def penguins_logistic_model(
session, penguins_logistic_model_name
Expand Down
36 changes: 36 additions & 0 deletions tests/system/small/ml/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,42 @@ def test_to_gbq_saved_linear_reg_model_scores(
)


def test_linear_reg_model_global_explain(
    penguins_linear_model_w_global_explain, new_penguins_df
):
    """Check the shape, column and feature index returned by global_explain()."""
    # Arrange: fit on the rows that have a label.
    model = penguins_linear_model_w_global_explain
    labeled_rows = new_penguins_df.dropna(subset=["body_mass_g"])
    features = labeled_rows.drop(columns=["body_mass_g"])
    label = labeled_rows[["body_mass_g"]]
    model.fit(features, label)

    # Act
    global_ex = model.global_explain()

    # Assert: one "attribution" column with six feature rows.
    assert global_ex.shape == (6, 1)
    pandas.testing.assert_index_equal(
        global_ex.columns, pandas.Index(["attribution"])
    )

    # Only the feature index is compared; the attribution values are dropped.
    feature_names = [
        "island",
        "species",
        "sex",
        "flipper_length_mm",
        "culmen_depth_mm",
        "culmen_length_mm",
    ]
    expected_feature = (
        pandas.DataFrame({"feature": feature_names})
        .set_index("feature")
        .sort_index()
    )
    result = global_ex.to_pandas().drop(["attribution"], axis=1).sort_index()
    pandas.testing.assert_frame_equal(
        result,
        expected_feature,
        check_exact=False,
        check_index_type=False,
    )


def test_to_gbq_replace(penguins_linear_model, table_id_unique):
penguins_linear_model.to_gbq(table_id_unique, replace=True)
with pytest.raises(google.api_core.exceptions.Conflict):
Expand Down