test -> val; docstr (#300)

* rename test -> val in custom metric function
* add an example in docstr
resolve #299
This commit is contained in:
Chi Wang 2021-11-22 22:17:29 -08:00 committed by GitHub
parent ea6d28d7bd
commit 85e21864ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 207 additions and 62 deletions

View File

@ -416,14 +416,38 @@ class AutoML(BaseEstimator):
.. code-block:: python
def custom_metric(
X_test, y_test, estimator, labels,
X_train, y_train, weight_test=None, weight_train=None,
config=None, groups_test=None, groups_train=None,
X_val, y_val, estimator, labels,
X_train, y_train, weight_val=None, weight_train=None,
config=None, groups_val=None, groups_train=None,
):
return metric_to_minimize, metrics_to_log
which returns a float number as the minimization objective,
and a dictionary as the metrics to log.
and a dictionary as the metrics to log. E.g.,
.. code-block:: python
def custom_metric(
X_val, y_val, estimator, labels,
X_train, y_train, weight_val=None, weight_train=None,
**args,
):
from sklearn.metrics import log_loss
import time
start = time.time()
y_pred = estimator.predict_proba(X_val)
pred_time = (time.time() - start) / len(X_val)
val_loss = log_loss(y_val, y_pred, labels=labels, sample_weight=weight_val)
y_pred = estimator.predict_proba(X_train)
train_loss = log_loss(y_train, y_pred, labels=labels, sample_weight=weight_train)
alpha = 0.5
return val_loss * (1 + alpha) - alpha * train_loss, {
"val_loss": val_loss,
"train_loss": train_loss,
"pred_time": pred_time,
}
task: A string of the task type, e.g.,
'classification', 'regression', 'ts_forecast', 'rank',
'seq-classification', 'seq-regression'.
@ -1641,14 +1665,38 @@ class AutoML(BaseEstimator):
.. code-block:: python
def custom_metric(
X_test, y_test, estimator, labels,
X_train, y_train, weight_test=None, weight_train=None,
config=None, groups_test=None, groups_train=None,
X_val, y_val, estimator, labels,
X_train, y_train, weight_val=None, weight_train=None,
config=None, groups_val=None, groups_train=None,
):
return metric_to_minimize, metrics_to_log
which returns a float number as the minimization objective,
and a dictionary as the metrics to log.
and a dictionary as the metrics to log. E.g.,
.. code-block:: python
def custom_metric(
X_val, y_val, estimator, labels,
X_train, y_train, weight_val=None, weight_train=None,
**args,
):
from sklearn.metrics import log_loss
import time
start = time.time()
y_pred = estimator.predict_proba(X_val)
pred_time = (time.time() - start) / len(X_val)
val_loss = log_loss(y_val, y_pred, labels=labels, sample_weight=weight_val)
y_pred = estimator.predict_proba(X_train)
train_loss = log_loss(y_train, y_pred, labels=labels, sample_weight=weight_train)
alpha = 0.5
return val_loss * (1 + alpha) - alpha * train_loss, {
"val_loss": val_loss,
"train_loss": train_loss,
"pred_time": pred_time,
}
task: A string of the task type, e.g.,
'classification', 'regression', 'ts_forecast', 'rank',
'seq-classification', 'seq-regression'.

View File

@ -189,10 +189,10 @@ def _eval_estimator(
estimator,
X_train,
y_train,
X_test,
y_test,
weight_test,
groups_test,
X_val,
y_val,
weight_val,
groups_val,
eval_metric,
obj,
labels=None,
@ -201,10 +201,10 @@ def _eval_estimator(
):
if isinstance(eval_metric, str):
pred_start = time.time()
test_pred_y = get_y_pred(estimator, X_test, eval_metric, obj)
pred_time = (time.time() - pred_start) / X_test.shape[0]
test_loss = sklearn_metric_loss_score(
eval_metric, test_pred_y, y_test, labels, weight_test, groups_test
val_pred_y = get_y_pred(estimator, X_val, eval_metric, obj)
pred_time = (time.time() - pred_start) / X_val.shape[0]
val_loss = sklearn_metric_loss_score(
eval_metric, val_pred_y, y_val, labels, weight_val, groups_val
)
metric_for_logging = {}
if log_training_metric:
@ -218,34 +218,34 @@ def _eval_estimator(
fit_kwargs.get("groups"),
)
else: # customized metric function
test_loss, metric_for_logging = eval_metric(
X_test,
y_test,
val_loss, metric_for_logging = eval_metric(
X_val,
y_val,
estimator,
labels,
X_train,
y_train,
weight_test,
weight_val,
fit_kwargs.get("sample_weight"),
config,
groups_test,
groups_val,
fit_kwargs.get("groups"),
)
pred_time = metric_for_logging.get("pred_time", 0)
test_pred_y = None
# eval_metric may return test_pred_y but not necessarily. Setting None for now.
return test_loss, metric_for_logging, pred_time, test_pred_y
val_pred_y = None
# eval_metric may return val_pred_y but not necessarily. Setting None for now.
return val_loss, metric_for_logging, pred_time, val_pred_y
def get_test_loss(
def get_val_loss(
config,
estimator,
X_train,
y_train,
X_test,
y_test,
weight_test,
groups_test,
X_val,
y_val,
weight_val,
groups_val,
eval_metric,
obj,
labels=None,
@ -255,20 +255,20 @@ def get_test_loss(
):
start = time.time()
# if groups_test is not None:
# fit_kwargs['groups_val'] = groups_test
# fit_kwargs['X_val'] = X_test
# fit_kwargs['y_val'] = y_test
# if groups_val is not None:
# fit_kwargs['groups_val'] = groups_val
# fit_kwargs['X_val'] = X_val
# fit_kwargs['y_val'] = y_val
estimator.fit(X_train, y_train, budget, **fit_kwargs)
test_loss, metric_for_logging, pred_time, _ = _eval_estimator(
val_loss, metric_for_logging, pred_time, _ = _eval_estimator(
config,
estimator,
X_train,
y_train,
X_test,
y_test,
weight_test,
groups_test,
X_val,
y_val,
weight_val,
groups_val,
eval_metric,
obj,
labels,
@ -276,7 +276,7 @@ def get_test_loss(
fit_kwargs,
)
train_time = time.time() - start
return test_loss, metric_for_logging, train_time, pred_time
return val_loss, metric_for_logging, train_time, pred_time
def evaluate_model_CV(
@ -349,7 +349,7 @@ def evaluate_model_CV(
groups_val = groups[val_index]
else:
groups_val = None
val_loss_i, metric_i, train_time_i, pred_time_i = get_test_loss(
val_loss_i, metric_i, train_time_i, pred_time_i = get_val_loss(
config,
estimator,
X_train,
@ -427,7 +427,7 @@ def compute_estimator(
n_jobs=n_jobs,
)
if "holdout" == eval_method:
val_loss, metric_for_logging, train_time, pred_time = get_test_loss(
val_loss, metric_for_logging, train_time, pred_time = get_val_loss(
config_dic,
estimator,
X_train,

View File

@ -34,9 +34,106 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: flaml[notebook] in /usr/local/lib/python3.9/site-packages (0.7.1)\n",
"Requirement already satisfied: scipy>=1.4.1 in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (1.7.2)\n",
"Requirement already satisfied: lightgbm>=2.3.1 in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (3.3.1)\n",
"Requirement already satisfied: pandas>=1.1.4 in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (1.3.4)\n",
"Requirement already satisfied: NumPy>=1.16.2 in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (1.21.4)\n",
"Requirement already satisfied: xgboost<=1.3.3,>=0.90 in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (1.3.3)\n",
"Requirement already satisfied: scikit-learn>=0.24 in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (1.0.1)\n",
"Requirement already satisfied: catboost>=0.26 in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (1.0.3)\n",
"Requirement already satisfied: jupyter in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (1.0.0)\n",
"Requirement already satisfied: rgf-python in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (3.11.0)\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (3.5.0)\n",
"Requirement already satisfied: openml==0.10.2 in /usr/local/lib/python3.9/site-packages (from flaml[notebook]) (0.10.2)\n",
"Requirement already satisfied: xmltodict in /usr/local/lib/python3.9/site-packages (from openml==0.10.2->flaml[notebook]) (0.12.0)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.9/site-packages (from openml==0.10.2->flaml[notebook]) (2.26.0)\n",
"Requirement already satisfied: liac-arff>=2.4.0 in /usr/local/lib/python3.9/site-packages (from openml==0.10.2->flaml[notebook]) (2.5.0)\n",
"Requirement already satisfied: python-dateutil in /usr/local/lib/python3.9/site-packages (from openml==0.10.2->flaml[notebook]) (2.8.2)\n",
"Requirement already satisfied: plotly in /usr/local/lib/python3.9/site-packages (from catboost>=0.26->flaml[notebook]) (5.4.0)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.9/site-packages (from catboost>=0.26->flaml[notebook]) (1.16.0)\n",
"Requirement already satisfied: graphviz in /usr/local/lib/python3.9/site-packages (from catboost>=0.26->flaml[notebook]) (0.18.2)\n",
"Requirement already satisfied: wheel in /usr/local/lib/python3.9/site-packages (from lightgbm>=2.3.1->flaml[notebook]) (0.37.0)\n",
"Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.9/site-packages (from pandas>=1.1.4->flaml[notebook]) (2021.3)\n",
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.9/site-packages (from scikit-learn>=0.24->flaml[notebook]) (1.1.0)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.9/site-packages (from scikit-learn>=0.24->flaml[notebook]) (3.0.0)\n",
"Requirement already satisfied: ipykernel in /usr/local/lib/python3.9/site-packages (from jupyter->flaml[notebook]) (6.5.1)\n",
"Requirement already satisfied: qtconsole in /usr/local/lib/python3.9/site-packages (from jupyter->flaml[notebook]) (5.2.0)\n",
"Requirement already satisfied: notebook in /usr/local/lib/python3.9/site-packages (from jupyter->flaml[notebook]) (6.4.6)\n",
"Requirement already satisfied: ipywidgets in /usr/local/lib/python3.9/site-packages (from jupyter->flaml[notebook]) (7.6.5)\n",
"Requirement already satisfied: jupyter-console in /usr/local/lib/python3.9/site-packages (from jupyter->flaml[notebook]) (6.4.0)\n",
"Requirement already satisfied: nbconvert in /usr/local/lib/python3.9/site-packages (from jupyter->flaml[notebook]) (6.3.0)\n",
"Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/lib/python3.9/site-packages (from matplotlib->flaml[notebook]) (3.0.6)\n",
"Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.9/site-packages (from matplotlib->flaml[notebook]) (8.4.0)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.9/site-packages (from matplotlib->flaml[notebook]) (1.3.2)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.9/site-packages (from matplotlib->flaml[notebook]) (4.28.1)\n",
"Requirement already satisfied: setuptools-scm>=4 in /usr/local/lib/python3.9/site-packages (from matplotlib->flaml[notebook]) (6.3.2)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.9/site-packages (from matplotlib->flaml[notebook]) (0.11.0)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.9/site-packages (from matplotlib->flaml[notebook]) (21.3)\n",
"Requirement already satisfied: tomli>=1.0.0 in /usr/local/lib/python3.9/site-packages (from setuptools-scm>=4->matplotlib->flaml[notebook]) (1.2.2)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.9/site-packages (from setuptools-scm>=4->matplotlib->flaml[notebook]) (57.5.0)\n",
"Requirement already satisfied: traitlets<6.0,>=5.1.0 in /usr/local/lib/python3.9/site-packages (from ipykernel->jupyter->flaml[notebook]) (5.1.1)\n",
"Requirement already satisfied: tornado<7.0,>=4.2 in /usr/local/lib/python3.9/site-packages (from ipykernel->jupyter->flaml[notebook]) (6.1)\n",
"Requirement already satisfied: jupyter-client<8.0 in /usr/local/lib/python3.9/site-packages (from ipykernel->jupyter->flaml[notebook]) (7.0.6)\n",
"Requirement already satisfied: ipython>=7.23.1 in /usr/local/lib/python3.9/site-packages (from ipykernel->jupyter->flaml[notebook]) (7.29.0)\n",
"Requirement already satisfied: debugpy<2.0,>=1.0.0 in /usr/local/lib/python3.9/site-packages (from ipykernel->jupyter->flaml[notebook]) (1.5.1)\n",
"Requirement already satisfied: matplotlib-inline<0.2.0,>=0.1.0 in /usr/local/lib/python3.9/site-packages (from ipykernel->jupyter->flaml[notebook]) (0.1.3)\n",
"Requirement already satisfied: widgetsnbextension~=3.5.0 in /usr/local/lib/python3.9/site-packages (from ipywidgets->jupyter->flaml[notebook]) (3.5.2)\n",
"Requirement already satisfied: jupyterlab-widgets>=1.0.0 in /usr/local/lib/python3.9/site-packages (from ipywidgets->jupyter->flaml[notebook]) (1.0.2)\n",
"Requirement already satisfied: nbformat>=4.2.0 in /usr/local/lib/python3.9/site-packages (from ipywidgets->jupyter->flaml[notebook]) (5.1.3)\n",
"Requirement already satisfied: ipython-genutils~=0.2.0 in /usr/local/lib/python3.9/site-packages (from ipywidgets->jupyter->flaml[notebook]) (0.2.0)\n",
"Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.9/site-packages (from jupyter-console->jupyter->flaml[notebook]) (3.0.22)\n",
"Requirement already satisfied: pygments in /usr/local/lib/python3.9/site-packages (from jupyter-console->jupyter->flaml[notebook]) (2.10.0)\n",
"Requirement already satisfied: jupyter-core in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (4.9.1)\n",
"Requirement already satisfied: bleach in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (4.1.0)\n",
"Requirement already satisfied: jinja2>=2.4 in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (3.0.3)\n",
"Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (0.3)\n",
"Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (0.5.9)\n",
"Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (0.1.2)\n",
"Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (1.5.0)\n",
"Requirement already satisfied: defusedxml in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (0.7.1)\n",
"Requirement already satisfied: testpath in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (0.5.0)\n",
"Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.9/site-packages (from nbconvert->jupyter->flaml[notebook]) (0.8.4)\n",
"Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.9/site-packages (from notebook->jupyter->flaml[notebook]) (1.8.0)\n",
"Requirement already satisfied: prometheus-client in /usr/local/lib/python3.9/site-packages (from notebook->jupyter->flaml[notebook]) (0.12.0)\n",
"Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.9/site-packages (from notebook->jupyter->flaml[notebook]) (22.3.0)\n",
"Requirement already satisfied: nest-asyncio>=1.5 in /usr/local/lib/python3.9/site-packages (from notebook->jupyter->flaml[notebook]) (1.5.1)\n",
"Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.9/site-packages (from notebook->jupyter->flaml[notebook]) (21.1.0)\n",
"Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.9/site-packages (from notebook->jupyter->flaml[notebook]) (0.12.1)\n",
"Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.9/site-packages (from plotly->catboost>=0.26->flaml[notebook]) (8.0.1)\n",
"Requirement already satisfied: qtpy in /usr/local/lib/python3.9/site-packages (from qtconsole->jupyter->flaml[notebook]) (1.11.2)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/site-packages (from requests->openml==0.10.2->flaml[notebook]) (1.26.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/site-packages (from requests->openml==0.10.2->flaml[notebook]) (2021.10.8)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/site-packages (from requests->openml==0.10.2->flaml[notebook]) (3.3)\n",
"Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.9/site-packages (from requests->openml==0.10.2->flaml[notebook]) (2.0.7)\n",
"Requirement already satisfied: decorator in /usr/local/lib/python3.9/site-packages (from ipython>=7.23.1->ipykernel->jupyter->flaml[notebook]) (5.1.0)\n",
"Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.9/site-packages (from ipython>=7.23.1->ipykernel->jupyter->flaml[notebook]) (4.8.0)\n",
"Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.9/site-packages (from ipython>=7.23.1->ipykernel->jupyter->flaml[notebook]) (0.18.1)\n",
"Requirement already satisfied: pickleshare in /usr/local/lib/python3.9/site-packages (from ipython>=7.23.1->ipykernel->jupyter->flaml[notebook]) (0.7.5)\n",
"Requirement already satisfied: backcall in /usr/local/lib/python3.9/site-packages (from ipython>=7.23.1->ipykernel->jupyter->flaml[notebook]) (0.2.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.9/site-packages (from jinja2>=2.4->nbconvert->jupyter->flaml[notebook]) (2.0.1)\n",
"Requirement already satisfied: jsonschema!=2.5.0,>=2.4 in /usr/local/lib/python3.9/site-packages (from nbformat>=4.2.0->ipywidgets->jupyter->flaml[notebook]) (4.2.1)\n",
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.9/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->jupyter-console->jupyter->flaml[notebook]) (0.2.5)\n",
"Requirement already satisfied: ptyprocess in /usr/local/lib/python3.9/site-packages (from terminado>=0.8.3->notebook->jupyter->flaml[notebook]) (0.7.0)\n",
"Requirement already satisfied: cffi>=1.0.0 in /usr/local/lib/python3.9/site-packages (from argon2-cffi->notebook->jupyter->flaml[notebook]) (1.15.0)\n",
"Requirement already satisfied: webencodings in /usr/local/lib/python3.9/site-packages (from bleach->nbconvert->jupyter->flaml[notebook]) (0.5.1)\n",
"Requirement already satisfied: pycparser in /usr/local/lib/python3.9/site-packages (from cffi>=1.0.0->argon2-cffi->notebook->jupyter->flaml[notebook]) (2.21)\n",
"Requirement already satisfied: parso<0.9.0,>=0.8.0 in /usr/local/lib/python3.9/site-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel->jupyter->flaml[notebook]) (0.8.2)\n",
"Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.9/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets->jupyter->flaml[notebook]) (21.2.0)\n",
"Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.9/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets->jupyter->flaml[notebook]) (0.18.0)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n",
"\u001b[33mWARNING: You are using pip version 21.3; however, version 21.3.1 is available.\n",
"You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n"
]
}
],
"source": [
"!pip install flaml[notebook];\n",
"# from v0.6.6, catboost is made an optional dependency to build conda package.\n",
@ -977,7 +1074,7 @@
"source": [
"## 5. Customized Metric\n",
"\n",
"It's also easy to customize the optimization metric. As an example, we demonstrate with a custom metric function which combines training loss and test loss as the final loss to minimize."
"It's also easy to customize the optimization metric. As an example, we demonstrate with a custom metric function which combines training loss and validation loss as the final loss to minimize."
]
},
{
@ -986,22 +1083,22 @@
"metadata": {},
"outputs": [],
"source": [
"def custom_metric(X_test, y_test, estimator, labels, X_train, y_train,\n",
" weight_test=None, weight_train=None, config=None,\n",
" groups_test=None, groups_train=None):\n",
"def custom_metric(X_val, y_val, estimator, labels, X_train, y_train,\n",
" weight_val=None, weight_train=None, config=None,\n",
" groups_val=None, groups_train=None):\n",
" from sklearn.metrics import log_loss\n",
" import time\n",
" start = time.time()\n",
" y_pred = estimator.predict_proba(X_test)\n",
" pred_time = (time.time() - start) / len(X_test)\n",
" test_loss = log_loss(y_test, y_pred, labels=labels,\n",
" sample_weight=weight_test)\n",
" y_pred = estimator.predict_proba(X_val)\n",
" pred_time = (time.time() - start) / len(X_val)\n",
" val_loss = log_loss(y_val, y_pred, labels=labels,\n",
" sample_weight=weight_val)\n",
" y_pred = estimator.predict_proba(X_train)\n",
" train_loss = log_loss(y_train, y_pred, labels=labels,\n",
" sample_weight=weight_train)\n",
" alpha = 0.5\n",
" return test_loss * (1 + alpha) - alpha * train_loss, {\n",
" \"test_loss\": test_loss, \"train_loss\": train_loss, \"pred_time\": pred_time\n",
" return val_loss * (1 + alpha) - alpha * train_loss, {\n",
" \"val_loss\": val_loss, \"train_loss\": train_loss, \"pred_time\": pred_time\n",
" }\n",
" # two elements are returned:\n",
" # the first element is the metric to minimize as a float number,\n",

View File

@ -98,30 +98,30 @@ class MyLargeLGBM(LGBMEstimator):
def custom_metric(
X_test,
y_test,
X_val,
y_val,
estimator,
labels,
X_train,
y_train,
weight_test=None,
weight_val=None,
weight_train=None,
config=None,
groups_test=None,
groups_val=None,
groups_train=None,
):
from sklearn.metrics import log_loss
import time
start = time.time()
y_pred = estimator.predict_proba(X_test)
pred_time = (time.time() - start) / len(X_test)
test_loss = log_loss(y_test, y_pred, labels=labels, sample_weight=weight_test)
y_pred = estimator.predict_proba(X_val)
pred_time = (time.time() - start) / len(X_val)
val_loss = log_loss(y_val, y_pred, labels=labels, sample_weight=weight_val)
y_pred = estimator.predict_proba(X_train)
train_loss = log_loss(y_train, y_pred, labels=labels, sample_weight=weight_train)
alpha = 0.5
return test_loss * (1 + alpha) - alpha * train_loss, {
"test_loss": test_loss,
return val_loss * (1 + alpha) - alpha * train_loss, {
"val_loss": val_loss,
"train_loss": train_loss,
"pred_time": pred_time,
}