Open
Description
I am computing permutation feature importance on a dataset. When I check the results against the R implementation, R gives non-zero variable importances, but the Python code below returns all zeros. What could be the reason? Here is my code:
from rfpimp import *
import numpy as np
from sklearn.base import clone
# Private helper from older scikit-learn releases; in 0.22+ it lives in
# sklearn.ensemble._forest and also takes an n_samples_bootstrap argument.
from sklearn.ensemble.forest import _generate_unsampled_indices

# TODO: add arg for subsample size to compute oob score
def oob_classifier_accuracy(rf, X_train, y_train):
    """OOB accuracy: each tree scores only the rows left out of its bootstrap sample."""
    X = X_train.values
    y = y_train.values
    n_samples = len(X)
    n_classes = len(np.unique(y))
    predictions = np.zeros((n_samples, n_classes))
    for tree in rf.estimators_:
        # Rows that were not drawn into this tree's bootstrap sample
        unsampled_indices = _generate_unsampled_indices(tree.random_state, n_samples)
        tree_preds = tree.predict_proba(X[unsampled_indices, :])
        predictions[unsampled_indices] += tree_preds
    # Soft vote: class with the largest summed OOB probability per row
    predicted_class_indexes = np.argmax(predictions, axis=1)
    predicted_classes = [rf.classes_[i] for i in predicted_class_indexes]
    oob_score = np.mean(y == predicted_classes)
    return oob_score

def permutation_importances(rf, X_train, y_train, metric):
    """
    Return importances from pre-fit rf; metric is function
    that measures accuracy or R^2 or similar. This function
    works for regressors and classifiers.
    """
    baseline = metric(rf, X_train, y_train)
    imp = []
    for col in X_train.columns:
        save = X_train[col].copy()
        # Shuffle one column, rescore the model, then restore the column
        X_train[col] = np.random.permutation(X_train[col])
        m = metric(rf, X_train, y_train)
        X_train[col] = save
        imp.append(baseline - m)
    return np.array(imp)

# base_rf, X_train and y_train are defined earlier (not shown)
rf = clone(base_rf)
rf.fit(X_train, y_train)
oob = oob_classifier_accuracy(rf, X_train, y_train)
print("oob accuracy", oob)
imp = permutation_importances(rf, X_train, y_train,
                              oob_classifier_accuracy)
imp
This gives the following output:
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
I also computed oob_classifier_accuracy() after permuting all of the variables at once, and the reported accuracy does not change at all. The event rate in the data is rather low, around 5%.
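One plausible explanation, given the 5% event rate: if the summed OOB probabilities rarely or never favor the minority class, the argmax prediction is the majority class for every row, so OOB accuracy stays at the majority rate whether or not a feature is permuted, and every `baseline - m` difference is exactly 0. The toy sketch below uses made-up probabilities (not taken from this dataset) to show accuracy staying flat while a probability-based score, the mean log-probability of the true class (negative log loss), still drops. The helper names `accuracy_from_proba`, `neg_log_loss_from_proba` and `two_class_proba` are hypothetical and are not part of rfpimp or scikit-learn.

```python
import numpy as np

def accuracy_from_proba(proba, y):
    """Fraction of rows whose argmax class equals the true label."""
    return np.mean(np.argmax(proba, axis=1) == y)

def neg_log_loss_from_proba(proba, y, eps=1e-15):
    """Mean log-probability assigned to the true class (higher is better)."""
    p_true = np.clip(proba[np.arange(len(y)), y], eps, 1.0)
    return np.mean(np.log(p_true))

def two_class_proba(p_pos):
    """Stack P(class 0) and P(class 1) into an (n, 2) array."""
    return np.column_stack([1.0 - p_pos, p_pos])

# Toy labels: 5 positives out of 100 rows (5% event rate).
y = np.zeros(100, dtype=int)
y[:5] = 1

# Hypothetical OOB probabilities before permuting an informative feature:
# positives get a higher P(class 1) (0.4), but it is still below 0.5.
p_before = np.where(y == 1, 0.4, 0.02)
# Hypothetical probabilities after permuting it: the signal is gone and
# every row falls back to roughly the base rate.
p_after = np.full(100, 0.05)

proba_before = two_class_proba(p_before)
proba_after = two_class_proba(p_after)

# Importance = baseline score - permuted score, as in permutation_importances.
acc_importance = (accuracy_from_proba(proba_before, y)
                  - accuracy_from_proba(proba_after, y))
ll_importance = (neg_log_loss_from_proba(proba_before, y)
                 - neg_log_loss_from_proba(proba_after, y))

print("accuracy importance:", acc_importance)
print("neg-log-loss importance:", round(ll_importance, 4))
```

In this toy case every row is predicted as class 0 both before and after the shuffle, so the accuracy importance is exactly 0.0, while the negative-log-loss importance is positive (about 0.13). Under the same assumption, a metric with the `metric(rf, X_train, y_train)` signature that scores the summed OOB `predict_proba` outputs directly, instead of their argmax, could be passed to `permutation_importances`; whether that brings the results closer to the R output would still need to be checked on the real data.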