example update (#281)

* example update

* bump version to 0.7.2

* notebook update
This commit is contained in:
Chi Wang 2021-11-12 22:29:33 -08:00 committed by GitHub
parent 92ebd1f7f9
commit 59083fbdcb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 77 additions and 79 deletions

View File

@ -88,7 +88,7 @@ automl_settings = {
"time_budget": 10, # in seconds
"metric": 'accuracy',
"task": 'classification',
"log_file_name": "test/iris.log",
"log_file_name": "iris.log",
}
X_train, y_train = load_iris(return_X_y=True)
# Train with labeled input data
@ -112,7 +112,7 @@ automl_settings = {
"time_budget": 10, # in seconds
"metric": 'r2',
"task": 'regression',
"log_file_name": "test/california.log",
"log_file_name": "california.log",
}
X_train, y_train = fetch_california_housing(return_X_y=True)
# Train with labeled input data
@ -124,7 +124,7 @@ print(automl.predict(X_train))
print(automl.model.estimator)
```
* Time series forecasting.
* A basic time series forecasting example.
```python
# pip install flaml[ts_forecast]
@ -137,7 +137,7 @@ automl.fit(X_train=X_train[:72], # a single column of timestamp
y_train=y_train, # value for each timestamp
period=12, # time horizon to forecast, e.g., 12 months
task='ts_forecast', time_budget=15, # time budget in seconds
log_file_name="test/ts_forecast.log",
log_file_name="ts_forecast.log",
)
print(automl.predict(X_train[72:]))
```

View File

@ -1 +1 @@
__version__ = "0.7.1"
__version__ = "0.7.2"

View File

@ -2,35 +2,37 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Pytorch model tuning example on CIFAR10\n",
"This notebook uses flaml to tune a pytorch model on CIFAR10. It is modified based on [this example](https://docs.ray.io/en/master/tune/examples/cifar10_pytorch.html).\n",
"\n",
"**Requirements.** This notebook requires:"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"!pip install torchvision flaml[blendsearch,ray];"
],
"outputs": [],
"metadata": {
"tags": []
}
},
"outputs": [],
"source": [
"!pip install torchvision flaml[blendsearch,ray];"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Network Specification"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
@ -60,20 +62,20 @@
" x = F.relu(self.fc2(x))\n",
" x = self.fc3(x)\n",
" return x"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_data(data_dir=\"data\"):\n",
" transform = transforms.Compose([\n",
@ -88,20 +90,20 @@
" root=data_dir, train=False, download=True, transform=transform)\n",
"\n",
" return trainset, testset"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ray import tune\n",
"\n",
@ -199,20 +201,20 @@
"\n",
" tune.report(loss=(val_loss / val_steps), accuracy=correct / total)\n",
" print(\"Finished Training\")"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test Accuracy"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _test_accuracy(net, device=\"cpu\"):\n",
" trainset, testset = load_data()\n",
@ -232,41 +234,41 @@
" correct += (predicted == labels).sum().item()\n",
"\n",
" return correct / total"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hyperparameter Optimization"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import flaml\n",
"import ray\n",
"import os\n",
"\n",
"data_dir = os.path.abspath(\"data\")\n",
"load_data(data_dir) # Download data for all trials before starting the run"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Search space"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"max_num_epoch = 100\n",
"config = {\n",
@ -276,32 +278,32 @@
" \"num_epochs\": tune.loguniform(1, max_num_epoch),\n",
" \"batch_size\": tune.randint(1, 5) # log transformed with base 2\n",
"}"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"time_budget_s = 600 # time budget in seconds\n",
"gpus_per_trial = 0.5 # number of gpus for each trial; 0.5 means two training jobs can share one gpu\n",
"num_samples = 500 # maximal number of trials\n",
"np.random.seed(7654321)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Launch the tuning"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"start_time = time.time()\n",
@ -318,16 +320,29 @@
" local_dir='logs/',\n",
" num_samples=num_samples,\n",
" time_budget_s=time_budget_s,\n",
" use_ray=True)\n",
"\n",
"ray.shutdown()"
],
"outputs": [],
"metadata": {}
" use_ray=True)\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"#trials=44\n",
"time=1193.913584947586\n",
"Best trial config: {'l1': 8, 'l2': 8, 'lr': 0.0008818671030627281, 'num_epochs': 55.9513429004283, 'batch_size': 3}\n",
"Best trial final validation loss: 1.0694482081472874\n",
"Best trial final validation accuracy: 0.6389\n",
"Files already downloaded and verified\n",
"Files already downloaded and verified\n",
"Best trial test set accuracy: 0.6294\n"
]
}
],
"source": [
"print(f\"#trials={len(result.trials)}\")\n",
"print(f\"time={time.time()-start_time}\")\n",
@ -354,30 +369,16 @@
"\n",
"test_acc = _test_accuracy(best_trained_model, device)\n",
"print(\"Best trial test set accuracy: {}\".format(test_acc))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"#trials=44\n",
"time=1193.913584947586\n",
"Best trial config: {'l1': 8, 'l2': 8, 'lr': 0.0008818671030627281, 'num_epochs': 55.9513429004283, 'batch_size': 3}\n",
"Best trial final validation loss: 1.0694482081472874\n",
"Best trial final validation accuracy: 0.6389\n",
"Files already downloaded and verified\n",
"Files already downloaded and verified\n",
"Best trial test set accuracy: 0.6294\n"
]
}
],
"metadata": {}
]
}
],
"metadata": {
"interpreter": {
"hash": "f7771e6a3915580179405189f5aa4eb9047494cbe4e005b29b851351b54902f6"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3.8.10 64-bit ('venv': venv)"
"display_name": "Python 3.8.10 64-bit ('venv': venv)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@ -389,17 +390,14 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.7.12"
},
"metadata": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
},
"interpreter": {
"hash": "f7771e6a3915580179405189f5aa4eb9047494cbe4e005b29b851351b54902f6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
}