Machine Learning Explainability


Core Idea: As machine learning models become more complex (i.e., “black boxes”), understanding why they make certain predictions is increasingly important. Machine learning explainability (or interpretability) provides tools to dissect and understand model behavior, building trust and enabling better debugging.

1. The Importance of Explainability

Understanding your model is crucial for:

- Debugging: surprising feature importances often reveal data leakage or preprocessing bugs.
- Informing feature engineering: knowing which features matter suggests new features to build.
- Directing future data collection: invest in gathering the data the model actually uses.
- Informing human decision-making: explanations let people sanity-check individual predictions.
- Building trust: stakeholders are more likely to act on a model they can understand.

2. Permutation Importance

Permutation importance is a simple yet powerful technique for measuring feature importance. After a model is fit, it randomly shuffles the values of a single feature in the validation data and measures the resulting drop in model performance. The more the performance drops, the more the model relies on that feature.

import eli5
from eli5.sklearn import PermutationImportance

# Fit the importance calculator on validation data, so the importances
# reflect generalization rather than training fit
perm = PermutationImportance(my_model, random_state=1).fit(val_X, val_y)
eli5.show_weights(perm, feature_names=val_X.columns.tolist())
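Besides eli5, scikit-learn ships its own permutation importance. A self-contained sketch on synthetic data (the dataset and model here are illustrative, not from the examples above):

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic regression problem: 4 features, only 2 carry signal
X, y = make_regression(n_samples=300, n_features=4, n_informative=2, random_state=0)
train_X, val_X, train_y, val_y = train_test_split(X, y, random_state=1)
model = RandomForestRegressor(random_state=1).fit(train_X, train_y)

# Shuffle each feature 10 times and record the average drop in R^2
result = permutation_importance(model, val_X, val_y, n_repeats=10, random_state=1)
for i in result.importances_mean.argsort()[::-1]:
    print(f"feature {i}: {result.importances_mean[i]:.3f} "
          f"+/- {result.importances_std[i]:.3f}")
```

The informative features should come out with clearly positive importances, while the noise features hover near zero.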

3. Partial Dependence Plots (PDP)

A partial dependence plot shows the marginal effect of a feature on the predicted outcome of a machine learning model. It visualizes the relationship between a feature and the prediction, averaging out the effects of all other features.

from sklearn.inspection import PartialDependenceDisplay

# Plot the average predicted outcome as one feature varies over its range
feature_to_plot = 'feature_name'
disp = PartialDependenceDisplay.from_estimator(my_model, val_X, [feature_to_plot])
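The averaging behind a PDP can be computed by hand, which makes the idea concrete: force every row to a given feature value, predict, and average. A minimal sketch (the dataset, model, and `manual_pdp` helper are illustrative assumptions, not part of any library API):

```python
import numpy as np
from sklearn.datasets import make_friedman1
from sklearn.ensemble import GradientBoostingRegressor

# Illustrative data and model
X, y = make_friedman1(n_samples=200, random_state=0)
model = GradientBoostingRegressor(random_state=0).fit(X, y)

def manual_pdp(model, X, feature, grid):
    """Average prediction when `feature` is fixed to each value in `grid`."""
    averages = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, feature] = value        # force every row to this value
        averages.append(model.predict(X_mod).mean())
    return np.array(averages)

grid = np.linspace(X[:, 0].min(), X[:, 0].max(), 20)
pdp = manual_pdp(model, X, feature=0, grid=grid)
```

Plotting `pdp` against `grid` reproduces what `PartialDependenceDisplay` draws for that feature.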

4. SHAP (SHapley Additive exPlanations)

SHAP is a game theory-based approach to explain the output of any machine learning model. It connects optimal credit allocation with local explanations using the classic Shapley values from game theory and their related extensions. SHAP values can explain both individual predictions and the overall model structure.
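The Shapley values underlying SHAP can be computed exactly for a tiny model by enumerating feature coalitions. A minimal sketch (the toy function `f`, baseline, and instance are made up for illustration; the SHAP library uses far faster algorithms):

```python
from itertools import combinations
from math import factorial
import numpy as np

# Toy "model": two additive terms plus one interaction term
def f(x):
    return 2 * x[0] + 3 * x[1] + x[0] * x[2]

baseline = np.zeros(3)            # reference input (all features "absent")
x = np.array([1.0, 2.0, 3.0])     # instance to explain

def shapley_value(f, x, baseline, i):
    """Exact Shapley value of feature i via the coalition-sum formula."""
    n = len(x)
    others = [j for j in range(n) if j != i]
    total = 0.0
    for size in range(n):
        for S in combinations(others, size):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            x_S = baseline.copy()         # coalition S present, rest at baseline
            for j in S:
                x_S[j] = x[j]
            x_Si = x_S.copy()             # same coalition plus feature i
            x_Si[i] = x[i]
            total += weight * (f(x_Si) - f(x_S))
    return total

phi = [shapley_value(f, x, baseline, i) for i in range(3)]
# Additivity: the values sum to f(x) - f(baseline)
```

Note how the interaction term `x[0] * x[2]` gets split evenly between features 0 and 2, which is exactly the "optimal credit allocation" property mentioned above.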

4.1 Explaining Individual Predictions

import shap

# TreeExplainer is the fast, exact explainer for tree-based models
explainer = shap.TreeExplainer(my_model)
# SHAP values for a single row: one contribution per feature
shap_values = explainer.shap_values(data_for_prediction)
shap.initjs()  # load the JavaScript needed for interactive plots
shap.force_plot(explainer.expected_value, shap_values, data_for_prediction)

4.2 Explaining the Full Model

# Compute SHAP values over the whole validation set, then summarize
shap_values = explainer.shap_values(val_X)
shap.summary_plot(shap_values, val_X)

Key Takeaways:

- Permutation importance answers "which features matter most?" quickly and works for any fitted model.
- Partial dependence plots show how a feature affects predictions on average.
- SHAP values decompose a single prediction into per-feature contributions, and aggregate into global summaries.

© 2018 JIAWEI    粤ICP备18035774号