Feature names and importances (#621)

* feature names and importances

* None check

* StackingClassifier has no feature_importances_

* StackingClassifier has no feature_names_in_
This commit is contained in:
Chi Wang 2022-07-10 12:25:59 -07:00 committed by GitHub
parent 59bd3c1979
commit e14e909af9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 1778 additions and 1015 deletions

View File

@ -856,6 +856,20 @@ class AutoML(BaseEstimator):
def n_features_in_(self):
return self._trained_estimator.n_features_in_
@property
def feature_names_in_(self):
attr = getattr(self, "_trained_estimator", None)
attr = attr and getattr(attr, "feature_names_in_", None)
if attr is not None:
return attr
return getattr(self, "_feature_names_in_", None)
@property
def feature_importances_(self):
attr = getattr(self, "_trained_estimator", None)
attr = attr and getattr(attr, "feature_importances_", None)
return attr
@property
def time_to_find_best_model(self) -> float:
"""Time taken to find best model in seconds."""
@ -1118,16 +1132,22 @@ class AutoML(BaseEstimator):
X, y, self._state.task
)
self._label_transformer = self._transformer.label_transformer
if hasattr(self._label_transformer, "label_list"):
self._state.fit_kwargs.update(
{"label_list": self._label_transformer.label_list}
)
elif self._state.task == TOKENCLASSIFICATION:
if "label_list" not in self._state.fit_kwargs:
if self._state.task == TOKENCLASSIFICATION:
if hasattr(self._label_transformer, "label_list"):
self._state.fit_kwargs.update(
{"label_list": self._label_transformer.label_list}
)
elif "label_list" not in self._state.fit_kwargs:
for each_fit_kwargs in self._state.fit_kwargs_by_estimator.values():
assert (
"label_list" in each_fit_kwargs
), "For the token-classification task, you must either (1) pass token labels; or (2) pass id labels and the label list. Please refer to the documentation for more details: https://microsoft.github.io/FLAML/docs/Examples/AutoML-NLP#a-simple-token-classification-example"
), "For the token-classification task, you must either (1) pass token labels; or (2) pass id labels and the label list. "
"Please refer to the documentation for more details: https://microsoft.github.io/FLAML/docs/Examples/AutoML-NLP#a-simple-token-classification-example"
self._feature_names_in_ = (
self._X_train_all.columns.to_list()
if hasattr(self._X_train_all, "columns")
else None
)
self._sample_weight_full = self._state.fit_kwargs.get(
"sample_weight"

View File

@ -134,6 +134,43 @@ class BaseEstimator:
"""Trained model after fit() is called, or None before fit() is called."""
return self._model
@property
def feature_names_in_(self):
"""
if self._model has attribute feature_names_in_, return it.
otherwise, if self._model has attribute feature_name_, return it.
otherwise, if self._model has attribute feature_names, return it.
otherwise, if self._model has method get_booster, return the feature names.
otherwise, return None.
"""
if hasattr(self._model, "feature_names_in_"): # for sklearn, xgboost>=1.6
return self._model.feature_names_in_
if hasattr(self._model, "feature_name_"): # for lightgbm
return self._model.feature_name_
if hasattr(self._model, "feature_names"): # for XGBoostEstimator
return self._model.feature_names
if hasattr(self._model, "get_booster"):
# get feature names for xgboost<1.6
# https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.Booster.feature_names
booster = self._model.get_booster()
return booster.feature_names
return None
@property
def feature_importances_(self):
"""
if self._model has attribute feature_importances_, return it.
otherwise, if self._model has attribute coef_, return it.
otherwise, return None.
"""
if hasattr(self._model, "feature_importances_"):
# for sklearn, lightgbm, catboost, xgboost
return self._model.feature_importances_
elif hasattr(self._model, "coef_"): # for linear models
return self._model.coef_
else:
return None
def _preprocess(self, X):
return X

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -115,6 +115,8 @@ class TestClassification(unittest.TestCase):
"ensemble": True,
}
automl.fit(X, y, **automl_settings)
print(automl.feature_names_in_)
print(automl.feature_importances_)
del automl
automl = AutoML()
@ -246,6 +248,8 @@ class TestClassification(unittest.TestCase):
)
automl = AutoML()
automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
print(automl.feature_names_in_)
print(automl.feature_importances_)
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "-U", "xgboost", "--user"]
)

View File

@ -86,6 +86,8 @@ def test_automl(budget=5, dataset_format="dataframe", hpo_method=None):
print(automl.resource_attr)
print(automl.max_resource)
print(automl.min_resource)
print(automl.feature_names_in_)
print(automl.feature_importances_)
if budget < performance_check_budget:
automl.fit(X_train=X_train, y_train=y_train, ensemble=True, **settings)

View File

@ -69,17 +69,29 @@ def test_prep():
lr = LRL2Classifier()
lr.fit(X, y)
lr.predict(X)
print(lr.feature_names_in_)
print(lr.feature_importances_)
lgbm = LGBMEstimator(n_estimators=4)
lgbm.fit(X, y)
print(lgbm.feature_names_in_)
print(lgbm.feature_importances_)
cat = CatBoostEstimator(n_estimators=4)
cat.fit(X, y)
print(cat.feature_names_in_)
print(cat.feature_importances_)
knn = KNeighborsEstimator(task="regression")
knn.fit(X, y)
print(knn.feature_names_in_)
print(knn.feature_importances_)
xgb = XGBoostEstimator(n_estimators=4, max_leaves=4)
xgb.fit(X, y)
xgb.predict(X)
print(xgb.feature_names_in_)
print(xgb.feature_importances_)
rf = RandomForestEstimator(task="regression", n_estimators=4, criterion="gini")
rf.fit(X, y)
print(rf.feature_names_in_)
print(rf.feature_importances_)
prophet = Prophet()
try:
@ -115,3 +127,9 @@ def test_prep():
lgbm.predict(X[:2])
lgbm.fit(X, y, period=2)
lgbm.predict(X[:2])
print(lgbm.feature_names_in_)
print(lgbm.feature_importances_)
if __name__ == "__main__":
test_prep()

View File

@ -14,7 +14,7 @@ settings = {
"time_budget": 60, # total running time in seconds
"metric": 'r2', # primary metrics for regression can be chosen from: ['mae','mse','r2']
"estimator_list": ['lgbm'], # list of ML learners; we tune lightgbm in this example
"task": 'regression', # task type
"task": 'regression', # task type
"log_file_name": 'houses_experiment.log', # flaml log file
"seed": 7654321, # random seed
}
@ -89,7 +89,7 @@ print(automl.model.estimator)
```python
import matplotlib.pyplot as plt
plt.barh(automl.model.estimator.feature_name_, automl.model.estimator.feature_importances_)
plt.barh(automl.feature_names_in_, automl.feature_importances_)
```
![png](../Use-Cases/images/feature_importance.png)

View File

@ -14,7 +14,7 @@ settings = {
"time_budget": 60, # total running time in seconds
"metric": 'r2', # primary metrics for regression can be chosen from: ['mae','mse','r2']
"estimator_list": ['xgboost'], # list of ML learners; we tune XGBoost in this example
"task": 'regression', # task type
"task": 'regression', # task type
"log_file_name": 'houses_experiment.log', # flaml log file
"seed": 7654321, # random seed
}
@ -119,7 +119,7 @@ print(automl.model.estimator)
```python
import matplotlib.pyplot as plt
plt.barh(X_train.columns, automl.model.estimator.feature_importances_)
plt.barh(automl.feature_names_in_, automl.feature_importances_)
```
![png](images/xgb_feature_importance.png)