diff --git a/bigframes/ml/core.py b/bigframes/ml/core.py
index ad00ed3f2c..01917fd6d8 100644
--- a/bigframes/ml/core.py
+++ b/bigframes/ml/core.py
@@ -134,6 +134,16 @@ def explain_predict(
             ),
         )
 
+    def global_explain(self, options: Mapping[str, bool]) -> bpd.DataFrame:
+        sql = self._model_manipulation_sql_generator.ml_global_explain(
+            struct_options=options
+        )
+        return (
+            self._session.read_gbq(sql)
+            .sort_values(by="attribution", ascending=False)
+            .set_index("feature")
+        )
+
     def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
         return self._apply_ml_tvf(
             input_data,
diff --git a/bigframes/ml/linear_model.py b/bigframes/ml/linear_model.py
index 46c5744a42..3774a62c0c 100644
--- a/bigframes/ml/linear_model.py
+++ b/bigframes/ml/linear_model.py
@@ -203,6 +203,29 @@ def predict_explain(
             X, options={"top_k_features": top_k_features}
         )
 
+    def global_explain(
+        self,
+    ) -> bpd.DataFrame:
+        """
+        Provide explanations for an entire linear regression model.
+
+        .. note::
+            Output matches that of the BigQuery ML.GLOBAL_EXPLAIN function.
+            See: https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-global-explain
+
+        Returns:
+            bigframes.pandas.DataFrame:
+                DataFrame containing feature importance values and corresponding attributions, designed to provide a global explanation of feature influence.
+
+        Raises:
+            RuntimeError: If the model has not been fitted.
+        """
+
+        if not self._bqml_model:
+            raise RuntimeError("A model must be fitted before global_explain")
+
+        return self._bqml_model.global_explain({})
+
     def score(
         self,
         X: utils.ArrayType,
diff --git a/bigframes/ml/sql.py b/bigframes/ml/sql.py
index b662d4c22c..e89f17bcaa 100644
--- a/bigframes/ml/sql.py
+++ b/bigframes/ml/sql.py
@@ -312,6 +312,12 @@ def ml_explain_predict(
         return f"""SELECT * FROM ML.EXPLAIN_PREDICT(MODEL {self._model_ref_sql()},
   ({source_sql}), {struct_options_sql})"""
 
+    def ml_global_explain(self, struct_options: Mapping[str, bool]) -> str:
+        """Encode ML.GLOBAL_EXPLAIN for BQML"""
+        struct_options_sql = self.struct_options(**struct_options)
+        return f"""SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL {self._model_ref_sql()},
+  {struct_options_sql})"""
+
     def ml_forecast(self, struct_options: Mapping[str, Union[int, float]]) -> str:
         """Encode ML.FORECAST for BQML"""
         struct_options_sql = self.struct_options(**struct_options)
diff --git a/samples/snippets/linear_regression_tutorial_test.py b/samples/snippets/linear_regression_tutorial_test.py
index e4ace53a5c..8fc1c5ad61 100644
--- a/samples/snippets/linear_regression_tutorial_test.py
+++ b/samples/snippets/linear_regression_tutorial_test.py
@@ -92,6 +92,31 @@ def test_linear_regression(random_model_id: str) -> None:
     # 3 5349.603734 [{'feature': 'island', 'attribution': 7348.877... -5320.222128 5349.603734 0.0 Gentoo penguin (Pygoscelis papua) Biscoe 46.4 15.6 221.0 5000.0 MALE
     # 4 4637.165037 [{'feature': 'island', 'attribution': 7348.877... -5320.222128 4637.165037 0.0 Gentoo penguin (Pygoscelis papua) Biscoe 46.1 13.2 211.0 4500.0 FEMALE
     # [END bigquery_dataframes_bqml_linear_predict_explain]
+    # [START bigquery_dataframes_bqml_linear_global_explain]
+    # To use the `global_explain()` function, the model must be recreated with `enable_global_explain` set to `True`.
+    model = LinearRegression(enable_global_explain=True)
+
+    # The model must then be fitted before it can be saved to BigQuery and then explained.
+    training_data = bq_df.dropna(subset=["body_mass_g"])
+    X = training_data.drop(columns=["body_mass_g"])
+    y = training_data[["body_mass_g"]]
+    model.fit(X, y)
+    model.to_gbq("bqml_tutorial.penguins_model", replace=True)
+
+    # Explain the model
+    explain_model = model.global_explain()
+
+    # Expected results:
+    #                    attribution
+    # feature
+    # island             5737.315921
+    # species            4073.280549
+    # sex                 622.070896
+    # flipper_length_mm   193.612051
+    # culmen_depth_mm     117.084944
+    # culmen_length_mm     94.366793
+    # [END bigquery_dataframes_bqml_linear_global_explain]
+    assert explain_model is not None
     assert feature_columns is not None
     assert label_columns is not None
     assert model is not None
diff --git a/tests/system/small/ml/conftest.py b/tests/system/small/ml/conftest.py
index 0e8489c513..2b9392f523 100644
--- a/tests/system/small/ml/conftest.py
+++ b/tests/system/small/ml/conftest.py
@@ -84,6 +84,15 @@ def ephemera_penguins_linear_model(
     return bf_model
 
 
+@pytest.fixture(scope="function")
+def penguins_linear_model_w_global_explain(
+    penguins_bqml_linear_model: core.BqmlModel,
+) -> linear_model.LinearRegression:
+    bf_model = linear_model.LinearRegression(enable_global_explain=True)
+    bf_model._bqml_model = penguins_bqml_linear_model
+    return bf_model
+
+
 @pytest.fixture(scope="session")
 def penguins_logistic_model(
     session, penguins_logistic_model_name
diff --git a/tests/system/small/ml/test_linear_model.py b/tests/system/small/ml/test_linear_model.py
index da9fc8e14f..8b04d55e61 100644
--- a/tests/system/small/ml/test_linear_model.py
+++ b/tests/system/small/ml/test_linear_model.py
@@ -228,6 +228,42 @@ def test_to_gbq_saved_linear_reg_model_scores(
     )
 
 
+def test_linear_reg_model_global_explain(
+    penguins_linear_model_w_global_explain, new_penguins_df
+):
+    training_data = new_penguins_df.dropna(subset=["body_mass_g"])
+    X = training_data.drop(columns=["body_mass_g"])
+    y = training_data[["body_mass_g"]]
+    penguins_linear_model_w_global_explain.fit(X, y)
+    global_ex = penguins_linear_model_w_global_explain.global_explain()
+    assert global_ex.shape == (6, 1)
+    expected_columns = pandas.Index(["attribution"])
+    pandas.testing.assert_index_equal(global_ex.columns, expected_columns)
+    result = global_ex.to_pandas().drop(["attribution"], axis=1).sort_index()
+    expected_feature = (
+        pandas.DataFrame(
+            {
+                "feature": [
+                    "island",
+                    "species",
+                    "sex",
+                    "flipper_length_mm",
+                    "culmen_depth_mm",
+                    "culmen_length_mm",
+                ]
+            },
+        )
+        .set_index("feature")
+        .sort_index()
+    )
+    pandas.testing.assert_frame_equal(
+        result,
+        expected_feature,
+        check_exact=False,
+        check_index_type=False,
+    )
+
+
 def test_to_gbq_replace(penguins_linear_model, table_id_unique):
     penguins_linear_model.to_gbq(table_id_unique, replace=True)
     with pytest.raises(google.api_core.exceptions.Conflict):