Basic Usage Example
Initializing libraries
import pandas as pd
# models
from xgboost import XGBRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.linear_model import LinearRegression
# data splitting and preprocessing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
# dataset
from ucimlrepo import fetch_ucirepo
# XAI-DESReg
from xaidesreg import DES
Loading data
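# Fetch the Abalone dataset (UCI id=1)
abalone = fetch_ucirepo(id=1)
# Copy the features so the 'Sex' column can be re-encoded without modifying the fetched object
X = abalone.data.features.copy()
y = abalone.data.targets
# Encode the categorical 'Sex' column as integers
le = LabelEncoder()
X['Sex'] = le.fit_transform(X['Sex'])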
Split dataset
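Split the dataset into training, DSEL (the validation set used to fit DES), and test sets. With the settings below, 20% of the data is held out for testing and 20% of the remainder is used as DSEL.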
TEST_SIZE = 0.2
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE, random_state=42)
X_train, X_dsel, y_train, y_dsel = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
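As a quick check of the proportions, the two successive splits leave roughly 64% of the rows for training, 16% for DSEL, and 20% for testing (roughly, because `train_test_split` rounds to whole rows). The helper below is only an illustration; the name `split_fractions` is hypothetical and is not part of xaidesreg or scikit-learn.

```python
def split_fractions(test_size, dsel_size):
    """Return (train, dsel, test) fractions after two successive hold-out splits."""
    test = test_size
    remaining = 1.0 - test_size   # share of rows left after removing the test set
    dsel = remaining * dsel_size  # DSEL is carved out of the remaining rows
    train = remaining - dsel
    return train, dsel, test


train, dsel, test = split_fractions(test_size=0.2, dsel_size=0.2)
print(f"train={train:.2f}, dsel={dsel:.2f}, test={test:.2f}")
```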
Model pool generation
pool_models = [
    XGBRegressor(random_state=42),
    RandomForestRegressor(random_state=42),
    LinearRegression(),
    KNeighborsRegressor(),
]
Train the models (pool):
for model in pool_models:
    model.fit(X_train, y_train)
Usage of our library:
des = DES(pool_regressors=pool_models,
          k=6,
          knn_metric='minkowski',
          metrics='mape',
          threshold=0.1)
des.fit(X_dsel, y_dsel)
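For intuition, the general idea behind dynamic ensemble selection for regression can be sketched in plain Python: for each query point, find its k nearest neighbours in the DSEL set, measure each model's error on those neighbours, and give more weight to the models that are locally more accurate. This is a simplified illustration and not the xaidesreg implementation. The function names, the Euclidean distance, and the inverse-error weighting are assumptions made for the example, and the `threshold` parameter is not modelled at all.

```python
import math


def mape(y_true, y_pred):
    """Mean absolute percentage error, as a fraction (0.1 == 10%)."""
    return sum(abs((t - p) / t) for t, p in zip(y_true, y_pred)) / len(y_true)


def nearest_neighbors(X_dsel, query, k):
    """Indices of the k DSEL rows closest to the query (Euclidean distance)."""
    order = sorted(range(len(X_dsel)), key=lambda i: math.dist(X_dsel[i], query))
    return order[:k]


def des_predict(models, X_dsel, y_dsel, query, k):
    """Weight each model by the inverse of its local MAPE, then average (assumed rule)."""
    idx = nearest_neighbors(X_dsel, query, k)
    local_X = [X_dsel[i] for i in idx]
    local_y = [y_dsel[i] for i in idx]
    weights = []
    for model in models:
        error = mape(local_y, [model(x) for x in local_X])
        weights.append(1.0 / (error + 1e-12))  # small constant avoids division by zero
    preds = [model(query) for model in models]
    return sum(w * p for w, p in zip(weights, preds)) / sum(weights)


# Toy data: the target is 2 * x; one model is accurate, the other is biased by +3.
X_dsel = [(1.0,), (2.0,), (3.0,), (4.0,), (5.0,)]
y_dsel = [2.0, 4.0, 6.0, 8.0, 10.0]
models = [lambda x: 2.0 * x[0], lambda x: 2.0 * x[0] + 3.0]

prediction = des_predict(models, X_dsel, y_dsel, query=(3.0,), k=3)
print(prediction)
```

In this toy setup the accurate model has a local error of zero, so it dominates the weighted average and the result stays very close to the true value of 6.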
Testing
des.score(X_test, y_test) # based on MSE
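The inline comment says the score is based on mean squared error (MSE). For reference, MSE is the average of the squared differences between the true and predicted values. A minimal stand-alone version, which is not the library's own code, looks like this:

```python
def mean_squared_error(y_true, y_pred):
    """Average of the squared residuals; lower is better."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


mse = mean_squared_error([9, 10, 8], [9.5, 9.0, 8.5])
print(mse)  # 0.5
```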
Explainability
index = 47
X_test.iloc[index]
prediction, contribution_df, neighbors_df = des.predict_xai(X_test.iloc[[index]])
print(prediction)
The content of the contribution data frame:
| Model | Predictions | Competence | Weights |
|---|---|---|---|
| XGBRegressor | 9.133 | 0.064 | 15.609 |
| RandomForestRegressor | 8.890 | 0.078 | 12.853 |
| LinearRegression | 8.913 | 0.090 | 11.141 |
| KNeighborsRegressor | 9.000 | 0.072 | 13.846 |
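In the table above, each weight is approximately the reciprocal of the corresponding competence value (for example, 1 / 0.064 ≈ 15.6), which suggests that competence is a local error and that a lower value earns a larger weight. Assuming, purely as an illustration, that the final prediction is the weight-normalized average of the individual predictions, the combination could be reproduced as follows. The exact aggregation rule is not stated here, so treat this as a sketch rather than a description of the library.

```python
# (prediction, weight) pairs copied from the contribution table above.
contributions = {
    "XGBRegressor": (9.133, 15.609),
    "RandomForestRegressor": (8.890, 12.853),
    "LinearRegression": (8.913, 11.141),
    "KNeighborsRegressor": (9.000, 13.846),
}

total_weight = sum(weight for _, weight in contributions.values())

# Assumed rule: weighted average of the model predictions.
weighted_prediction = (
    sum(pred * weight for pred, weight in contributions.values()) / total_weight
)

# Share of the final prediction attributed to each model under this assumption.
shares = {name: weight / total_weight for name, (_, weight) in contributions.items()}

print(round(weighted_prediction, 3))
for name, share in shares.items():
    print(f"{name}: {share:.1%}")
```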
The samples in the region of competence (neighbors_df):
| Sex | Length | Diameter | Height | Whole_weight | Shucked_weight | Viscera_weight | Shell_weight | Target |
|---|---|---|---|---|---|---|---|---|
| 2 | 0.565 | 0.44 | 0.125 | 0.802 | 0.3595 | 0.1825 | 0.215 | 9 |
| 2 | 0.55 | 0.425 | 0.15 | 0.8315 | 0.411 | 0.1765 | 0.2165 | 10 |
| 2 | 0.56 | 0.415 | 0.13 | 0.7615 | 0.3695 | 0.17 | 0.1955 | 8 |
| 2 | 0.545 | 0.41 | 0.12 | 0.793 | 0.434 | 0.1405 | 0.19 | 9 |
| 2 | 0.56 | 0.415 | 0.145 | 0.852 | 0.43 | 0.1885 | 0.205 | 8 |
| 2 | 0.54 | 0.42 | 0.135 | 0.8075 | 0.3485 | 0.1795 | 0.235 | 11 |
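One way to sanity-check an explanation is to compare the model predictions with the targets of the neighbours in the region of competence. The snippet below uses the six target values and the four model predictions listed above; the variable names are illustrative only.

```python
# Targets of the six neighbours in the region of competence.
neighbor_targets = [9, 10, 8, 9, 8, 11]
# Individual model predictions from the contribution data frame.
model_predictions = [9.133, 8.890, 8.913, 9.000]

mean_target = sum(neighbor_targets) / len(neighbor_targets)
low, high = min(neighbor_targets), max(neighbor_targets)

# Check that every model prediction falls inside the range of neighbour targets.
all_within_range = all(low <= p <= high for p in model_predictions)

print(f"neighbour targets: {low} to {high}, mean {mean_target:.2f}")
print(f"all predictions within range: {all_within_range}")
```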