forked from mindspore-Ecosystem/mindspore
add DeepBSDE
This commit is contained in:
parent
7dffa5096c
commit
25e3a360d3
|
@ -0,0 +1,182 @@
|
|||
# Contents
|
||||
|
||||
- [Contents](#contents)
|
||||
- [DeepBSDE Description](#DeepBSDE-description)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Model Description](#Model-Description)
|
||||
- [Evaluation Performance](#Evaluation-Performance)
|
||||
- [Inference Performance](#Inference-Performance)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [DeepBSDE Description](#contents)
|
||||
|
||||
DeepBSDE is a power of deep neural networks by developing a strategy for solving a large class of high-dimensional nonlinear PDEs using deep learning. The class of PDEs that we deal with is (nonlinear) parabolic PDEs.
|
||||
|
||||
[paper](https:#www.pnas.org/content/115/34/8505): Han J , Arnulf J , Weinan E . Solving high-dimensional partial differential equations using deep learning[J]. Proceedings of the National Academy of Sciences, 2018:201718942-.
|
||||
|
||||
## [HJB equation](#Contents)
|
||||
|
||||
Hamilton–Jacobi–Bellman Equation which is the term curse of dimensionality was first used explicitly by Richard Bellman in the context of dynamic programming, which has now become the cornerstone in many areas such as economics, behavioral science, computer science, and even biology, where intelligent decision making is the main issue.
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(GPU)
|
||||
- Prepare hardware environment with GPU processor.
|
||||
- Framework
|
||||
- [MindSpore](https:#www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https:#www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https:#www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Quick Start](#contents)
|
||||
|
||||
After installing MindSpore via the official website, you can start training and evaluation as follows:
|
||||
|
||||
```shell
|
||||
# Running training example
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
python train.py --config_path=./config/HJBLQ_config.yaml
|
||||
OR
|
||||
bash ./scripts/run_train.sh [CONFIG_YAML] [DEVICE_ID](option, default is 0)
|
||||
|
||||
# Running evaluation example
|
||||
python eval.py --config_path=./config/HJBLQ_config.yaml
|
||||
OR
|
||||
bash ./scripts/run_eval.sh [CONFIG_YAML] [DEVICE_ID](option, default is 0)
|
||||
```
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
|
||||
```text
|
||||
.
|
||||
├── config
|
||||
│ └── HJBLQ_config.yaml # default config for HJB equation.
|
||||
├── src
|
||||
│ ├── config.py # config parse script.
|
||||
│ ├── equation.py # equation definition and dataset helper.
|
||||
│ ├── eval_utils.py # evaluation callback and evaluation utils.
|
||||
│ └── net.py # DeepBSDE network structure.
|
||||
├── eval.py # evaluation API entry.
|
||||
├── export.py # export models API entry.
|
||||
├── README_CN.md
|
||||
├── README.md
|
||||
├── requirements.txt # requirements of third party package.
|
||||
└── train.py # train API entry.
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
Parameters for both training and evaluation can be set in `CONFIG_YAML`
|
||||
|
||||
- config for HBJ
|
||||
|
||||
```python
|
||||
# eqn config
|
||||
eqn_name: "HJBLQ" # Equation function name.
|
||||
total_time: 1.0 # The total time of equation function.
|
||||
dim: 100 # Hidden layer dims.
|
||||
num_time_interval: 20 # Number of interval times.
|
||||
|
||||
# net config
|
||||
y_init_range: [0, 1] # The y_init random initialization range.
|
||||
num_hiddens: [110, 110] # A list of hidden layer's filter number.
|
||||
lr_values: [0.01, 0.01] # lr_values of piecewise_constant_lr.
|
||||
lr_boundaries: [1000] # lr_boundaries of piecewise_constant_lr.
|
||||
num_iterations: 2000 # Iterations numbers.
|
||||
batch_size: 64 # batch_size when training.
|
||||
valid_size: 256 # batch_size when evaluation.
|
||||
logging_frequency: 100 # logging and evaluation callback frequency.
|
||||
|
||||
# other config
|
||||
device_target: "GPU" # Device where the code will be implemented. Optional values is GPU.
|
||||
log_dir: "./logs" # The path of log saving.
|
||||
file_format: "MINDIR" # Export model type.
|
||||
```
|
||||
|
||||
For more configuration details, please refer the yaml file `./config/HJBLQ_config.yaml`.
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
- Running on GPU
|
||||
|
||||
```bash
|
||||
python train.py --config_path=./config/HJBLQ_config.yaml > train.log 2>&1 &
|
||||
```
|
||||
|
||||
- The python command above will run in the background, you can view the results through the file `train.log`。
|
||||
|
||||
The loss value can be achieved as follows:
|
||||
|
||||
```log
|
||||
epoch: 1 step: 100, loss is 245.3738
|
||||
epoch time: 26883.370 ms, per step time: 268.834 ms
|
||||
total step: 100, eval loss: 1179.300, Y0: 1.400, elapsed time: 34
|
||||
epcoh: 2 step: 100, loss is 149.6593
|
||||
epoch time: 3184.401 ms, per step time: 32.877 ms
|
||||
total step: 200, eval loss: 659.457, Y0: 1.693, elapsed time: 37
|
||||
...
|
||||
```
|
||||
|
||||
After training, you'll get the last checkpoint file under the folder `log_dir` in config.
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
- Evaluation on GPU
|
||||
|
||||
Before running the command below, please check the checkpoint path used for evaluation. Such as `./log/deepbsde_HJBLQ_end.ckpt`
|
||||
|
||||
```bash
|
||||
python eval.py --config_path=./config/HJBLQ_config.yaml > eval.log 2>&1 &
|
||||
```
|
||||
|
||||
The above python command will run in the background. You can view the results through the file "eval.log". The error of evaluation is as follows:
|
||||
|
||||
```log
|
||||
eval loss: 5.146923065185527, Y0: 4.59813117980957
|
||||
```
|
||||
|
||||
# [Model Description](#contents)
|
||||
|
||||
## [Evaluation Performance](#contents)
|
||||
|
||||
| Parameters | GPU |
|
||||
| -------------------------- | ------------------------------------------------------------ |
|
||||
| Model Version | DeepBSDE |
|
||||
| Resource | NV SMX2 V100-32G |
|
||||
| uploaded Date | 7/5/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.2.0 | |
|
||||
| Training Parameters | step=2000, see `./config/HJBLQ_config.yaml` for details |
|
||||
| Optimizer | Adam | |
|
||||
| Loss | 2.11 |
|
||||
| Speed | 32ms/step |
|
||||
| Total time | 3 min |
|
||||
| Parameters | 650K |
|
||||
| Checkpoint for Fine tuning | 7.8M (.ckpt file) |
|
||||
|
||||
## [Inference Performance](#contents)
|
||||
|
||||
| Parameters | GPU |
|
||||
| ----------------- | -------------------------------------------- |
|
||||
| Model Version | DeepBSDE |
|
||||
| Resource | NV SMX2 V100-32G |
|
||||
| uploaded Date | 7/5/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.2.0 |
|
||||
| outputs | eval loss & Y0 |
|
||||
| Y0 | Y0: 4.59 |
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
||||
We use random in equation.py,which can be set seed to fixed randomness.
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https:#gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,41 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
|
||||
|
||||
# eqn config
|
||||
eqn_name: "HJBLQ"
|
||||
total_time: 1.0
|
||||
dim: 100
|
||||
num_time_interval: 20
|
||||
|
||||
# net config
|
||||
y_init_range: [0, 1]
|
||||
num_hiddens: [110, 110]
|
||||
lr_values: [0.01, 0.01]
|
||||
lr_boundaries: [1000]
|
||||
num_iterations: 2000
|
||||
batch_size: 64
|
||||
valid_size: 256
|
||||
logging_frequency: 100
|
||||
|
||||
# other config
|
||||
device_target: "GPU"
|
||||
log_dir: "./logs"
|
||||
file_format: "MINDIR"
|
||||
|
||||
---
|
||||
|
||||
# Help description for each configuration
|
||||
eqn_name: "Equation function name."
|
||||
total_time: "The total time of equation function."
|
||||
dim: "Hidden layer dims."
|
||||
num_time_interval: "Number of interval times."
|
||||
y_init_range: "The y_init random initialization range."
|
||||
num_hiddens: "A list of hidden layer's filter number."
|
||||
lr_values: "lr_values of piecewise_constant_lr."
|
||||
lr_boundaries: "lr_boundaries of piecewise_constant_lr."
|
||||
num_iterations: "Iterations numbers."
|
||||
batch_size: "batch_size when training."
|
||||
valid_size: "batch_size when evaluation."
|
||||
logging_frequency: "logging and evaluation callback frequency."
|
||||
device_target: "Device where the code will be implemented. Optional values is GPU."
|
||||
log_dir: "The path of log saving."
|
||||
file_format: "Export model type."
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""DeepBSDE evaluation script"""
|
||||
import os
|
||||
from mindspore import context, load_checkpoint
|
||||
from src.net import DeepBSDE, WithLossCell
|
||||
from src.config import config
|
||||
from src.equation import get_bsde
|
||||
from src.eval_utils import apply_eval
|
||||
|
||||
if __name__ == '__main__':
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
config.ckpt_path = os.path.join(config.log_dir, "deepbsde_{}_end.ckpt".format(config.eqn_name))
|
||||
bsde = get_bsde(config)
|
||||
print('Begin to solve', config.eqn_name)
|
||||
net = DeepBSDE(config, bsde)
|
||||
net_with_loss = WithLossCell(net)
|
||||
load_checkpoint(config.ckpt_path, net=net_with_loss)
|
||||
eval_param = {"model": net_with_loss, "valid_data": bsde.sample(config.valid_size)}
|
||||
loss, y_init = apply_eval(eval_param)
|
||||
print("eval loss: {}, Y0: {}".format(loss, y_init))
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""DeepBSDE export model script"""
|
||||
import os
|
||||
from mindspore import context, load_checkpoint, export, Tensor
|
||||
from src.net import DeepBSDE
|
||||
from src.config import config
|
||||
from src.equation import get_bsde
|
||||
|
||||
if __name__ == '__main__':
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
config.ckpt_path = os.path.join(config.log_dir, "deepbsde_{}_end.ckpt".format(config.eqn_name))
|
||||
bsde = get_bsde(config)
|
||||
print('Begin to solve', config.eqn_name)
|
||||
net = DeepBSDE(config, bsde)
|
||||
load_checkpoint(config.ckpt_path, net=net)
|
||||
dw, x = bsde.sample(config.valid_size)
|
||||
export(net, Tensor(dw), Tensor(x), file_name="deepbsde_{}".format(config.eqn_name), file_format=config.file_format)
|
|
@ -0,0 +1,2 @@
|
|||
scipy >= 1.5.2
|
||||
PyYAML
|
|
@ -0,0 +1,45 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
if [ $# != 1 ] && [ $# != 2 ]
|
||||
then
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash ./scripts/run_eval.sh [CONFIG_YAML] [DEVICE_ID](option, default is 0)"
|
||||
echo "for example: bash ./scripts/run_eval.sh ./config/HJBLQ_config.yaml 0"
|
||||
echo "=============================================================================================================="
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||
|
||||
export DEVICE_ID=0
|
||||
if [ $# == 2 ];
|
||||
then
|
||||
export DEVICE_ID=$2
|
||||
fi
|
||||
|
||||
config_yaml=$(get_real_path $1)
|
||||
|
||||
nohup python ${PROJECT_DIR}/../eval.py --config_path=$config_yaml > eval.log 2>&1 &
|
|
@ -0,0 +1,45 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
if [ $# != 1 ] && [ $# != 2 ]
|
||||
then
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash ./scripts/run_train.sh [CONFIG_YAML] [DEVICE_ID](option, default is 0)"
|
||||
echo "for example: bash ./scripts/run_train.sh ./config/HJBLQ_config.yaml 0"
|
||||
echo "=============================================================================================================="
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||
|
||||
export DEVICE_ID=0
|
||||
if [ $# == 2 ];
|
||||
then
|
||||
export DEVICE_ID=$2
|
||||
fi
|
||||
|
||||
config_yaml=$(get_real_path $1)
|
||||
|
||||
nohup python ${PROJECT_DIR}/../train.py --config_path=$config_yaml > train.log 2>&1 &
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Parse arguments"""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from pprint import pprint, pformat
|
||||
import yaml
|
||||
|
||||
class Config:
|
||||
"""
|
||||
Configuration namespace. Convert dictionary to members.
|
||||
"""
|
||||
def __init__(self, cfg_dict):
|
||||
for k, v in cfg_dict.items():
|
||||
if isinstance(v, (list, tuple)):
|
||||
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
|
||||
else:
|
||||
setattr(self, k, Config(v) if isinstance(v, dict) else v)
|
||||
|
||||
def __str__(self):
|
||||
return pformat(self.__dict__)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
|
||||
"""
|
||||
Parse command line arguments to the configuration according to the default yaml.
|
||||
|
||||
Args:
|
||||
parser: Parent parser.
|
||||
cfg: Base configuration.
|
||||
helper: Helper description.
|
||||
cfg_path: Path to the default yaml config.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
|
||||
parents=[parser])
|
||||
helper = {} if helper is None else helper
|
||||
choices = {} if choices is None else choices
|
||||
for item in cfg:
|
||||
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
|
||||
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
|
||||
choice = choices[item] if item in choices else None
|
||||
if isinstance(cfg[item], bool):
|
||||
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
else:
|
||||
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
|
||||
help=help_description)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def parse_yaml(yaml_path):
|
||||
"""
|
||||
Parse the yaml config file.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the yaml config.
|
||||
"""
|
||||
with open(yaml_path, 'r') as fin:
|
||||
try:
|
||||
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
|
||||
cfgs = [x for x in cfgs]
|
||||
if len(cfgs) == 1:
|
||||
cfg_helper = {}
|
||||
cfg = cfgs[0]
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 2:
|
||||
cfg, cfg_helper = cfgs
|
||||
cfg_choices = {}
|
||||
elif len(cfgs) == 3:
|
||||
cfg, cfg_helper, cfg_choices = cfgs
|
||||
else:
|
||||
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
|
||||
print(cfg_helper)
|
||||
except:
|
||||
raise ValueError("Failed to parse yaml")
|
||||
return cfg, cfg_helper, cfg_choices
|
||||
|
||||
|
||||
def merge(args, cfg):
|
||||
"""
|
||||
Merge the base config from yaml file and command line arguments.
|
||||
|
||||
Args:
|
||||
args: Command line arguments.
|
||||
cfg: Base configuration.
|
||||
"""
|
||||
args_var = vars(args)
|
||||
for item in args_var:
|
||||
cfg[item] = args_var[item]
|
||||
return cfg
|
||||
|
||||
|
||||
def get_config():
|
||||
"""
|
||||
Get Config according to the yaml file and cli arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../config/HJBLQ_config.yaml"),
|
||||
help="Config file path")
|
||||
path_args, _ = parser.parse_known_args()
|
||||
default, helper, choices = parse_yaml(path_args.config_path)
|
||||
pprint(default)
|
||||
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
|
||||
final_config = merge(args, default)
|
||||
return Config(final_config)
|
||||
|
||||
config = get_config()
|
|
@ -0,0 +1,104 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""equations for different PDE function."""
|
||||
|
||||
import numpy as np
|
||||
from scipy.stats import multivariate_normal as normal
|
||||
from mindspore import ops as P
|
||||
from mindspore import nn
|
||||
import mindspore.dataset as ds
|
||||
|
||||
class Equation():
|
||||
"""Base class for defining PDE related function."""
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.dim = cfg.dim
|
||||
self.total_time = cfg.total_time
|
||||
self.steps = cfg.num_iterations
|
||||
self.num_time_interval = cfg.num_time_interval
|
||||
self.delta_t = self.total_time / self.num_time_interval
|
||||
self.sqrt_delta_t = np.sqrt(self.delta_t)
|
||||
self.y_init = None
|
||||
self.num_sample = cfg.batch_size
|
||||
self.generator = P.Identity()
|
||||
self.terminal_condition = P.Identity()
|
||||
|
||||
def sample(self, num_sample):
|
||||
"""Sample forward SDE."""
|
||||
raise NotImplementedError
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.sample(self.num_sample)
|
||||
|
||||
@property
|
||||
def column_names(self):
|
||||
return ["dw", "x"]
|
||||
|
||||
def __len__(self):
|
||||
return self.steps
|
||||
|
||||
|
||||
class HJBLQ(Equation):
|
||||
"""HJB equation in PNAS paper doi.org/10.1073/pnas.1718942115"""
|
||||
def __init__(self, cfg):
|
||||
super(HJBLQ, self).__init__(cfg)
|
||||
self.x_init = np.zeros(self.dim)
|
||||
self.sigma = np.sqrt(2.0)
|
||||
self.generator = HJBLQGenerator(1.0)
|
||||
self.terminal_condition = HJBLQTerminalCondition()
|
||||
|
||||
def sample(self, num_sample):
|
||||
# draw random samples from a multivariate normal distribution
|
||||
dw_sample = normal.rvs(size=[num_sample,
|
||||
self.dim,
|
||||
self.num_time_interval]) * self.sqrt_delta_t # num_sample, dim, num_time_interval
|
||||
x_sample = np.zeros([num_sample, self.dim, self.num_time_interval + 1])
|
||||
for i in range(self.num_time_interval):
|
||||
x_sample[:, :, i + 1] = x_sample[:, :, i] + self.sigma * dw_sample[:, :, i]
|
||||
return dw_sample.astype(np.float32), x_sample.astype(np.float32)
|
||||
|
||||
|
||||
class HJBLQGenerator(nn.Cell):
|
||||
"""Generator function for HJBLQ"""
|
||||
def __init__(self, lambd):
|
||||
super(HJBLQGenerator, self).__init__()
|
||||
self.lambd = lambd
|
||||
self.sum = P.ReduceSum(keep_dims=True)
|
||||
self.square = P.Square()
|
||||
|
||||
def construct(self, t, x, y, z):
|
||||
res = -self.lambd * self.sum(self.square(z), 1)
|
||||
return res
|
||||
|
||||
|
||||
class HJBLQTerminalCondition(nn.Cell):
|
||||
"""Terminal condition for HJBLQ"""
|
||||
def __init__(self):
|
||||
super(HJBLQTerminalCondition, self).__init__()
|
||||
self.sum = P.ReduceSum(keep_dims=True)
|
||||
self.square = P.Square()
|
||||
|
||||
def construct(self, t, x):
|
||||
res = P.log((1 + self.sum(self.square(x), 1)) / 2)
|
||||
return res
|
||||
|
||||
def get_bsde(cfg):
|
||||
bsde_dict = {"HJBLQ": HJBLQ(cfg)}
|
||||
return bsde_dict[cfg.eqn_name.upper()]
|
||||
|
||||
def create_dataset(bsde):
|
||||
"""Get generator dataset when training."""
|
||||
dataset = ds.GeneratorDataset(bsde, bsde.column_names)
|
||||
return dataset
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Evaluation callback when training"""
|
||||
|
||||
import time
|
||||
from mindspore import Tensor, save_checkpoint
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
class EvalCallBack(Callback):
|
||||
"""
|
||||
Evaluation callback when training.
|
||||
|
||||
Args:
|
||||
eval_param_dict (dict): evaluation parameters' configure dict.
|
||||
ckpt_path (str): save checkpoint path format, eg: "./logs/deepbsde_hjb_{}.ckpt".
|
||||
interval (int): run evaluation interval, default is 1.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Examples:
|
||||
>>> EvalCallBack(eval_function, eval_param_dict)
|
||||
"""
|
||||
|
||||
def __init__(self, eval_param_dict, ckpt_path, interval=1):
|
||||
super(EvalCallBack, self).__init__()
|
||||
self.eval_param_dict = eval_param_dict
|
||||
self.eval_function = apply_eval
|
||||
if interval < 1:
|
||||
raise ValueError("interval should >= 1.")
|
||||
self.interval = interval
|
||||
self.best_res = 0
|
||||
self.best_epoch = 0
|
||||
self.ckpt_path = ckpt_path
|
||||
self.start_time = time.time()
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
"""Callback when epoch end."""
|
||||
cb_params = run_context.original_args()
|
||||
cur_step = cb_params.cur_step_num
|
||||
if cur_step % self.interval == 0:
|
||||
loss, y_init = self.eval_function(self.eval_param_dict)
|
||||
elapsed_time = time.time() - self.start_time
|
||||
print("total step: {:4d}, eval loss: {:5.3f}, Y0: {:5.3f}, elapsed time: {:3.0f}".format(
|
||||
cur_step, loss, y_init, elapsed_time))
|
||||
|
||||
def end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
save_checkpoint(cb_params.train_network, self.ckpt_path.format("end"))
|
||||
|
||||
def apply_eval(eval_param):
|
||||
eval_model = eval_param["model"]
|
||||
dw, x = eval_param["valid_data"]
|
||||
eval_model.set_train(False)
|
||||
loss = eval_model(Tensor(dw), Tensor(x)).asnumpy()
|
||||
y_init = eval_model.net.y_init.asnumpy()[0]
|
||||
return loss, y_init
|
|
@ -0,0 +1,119 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Define the network structure of DeepBSDE"""
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import nn
|
||||
from mindspore import ops as P
|
||||
from mindspore import Tensor, Parameter
|
||||
|
||||
|
||||
class DeepBSDE(nn.Cell):
|
||||
"""
|
||||
The network structure of DeepBSDE.
|
||||
|
||||
Args:
|
||||
cfg: configure settings.
|
||||
bsde(Cell): equation function
|
||||
"""
|
||||
def __init__(self, cfg, bsde):
|
||||
super(DeepBSDE, self).__init__()
|
||||
self.bsde = bsde
|
||||
self.delta_t = bsde.delta_t
|
||||
self.num_time_interval = bsde.num_time_interval
|
||||
self.dim = bsde.dim
|
||||
self.time_stamp = Tensor(np.arange(0, cfg.num_time_interval) * bsde.delta_t)
|
||||
self.y_init = Parameter(np.random.uniform(low=cfg.y_init_range[0],
|
||||
high=cfg.y_init_range[1],
|
||||
size=[1]).astype(np.float32))
|
||||
self.z_init = Parameter(np.random.uniform(low=-0.1, high=0.1, size=[1, cfg.dim]).astype(np.float32))
|
||||
|
||||
self.subnet = nn.CellList([FeedForwardSubNet(cfg.dim, cfg.num_hiddens)
|
||||
for _ in range(bsde.num_time_interval-1)])
|
||||
self.generator = bsde.generator
|
||||
self.matmul = P.MatMul()
|
||||
self.sum = P.ReduceSum(keep_dims=True)
|
||||
|
||||
def construct(self, dw, x):
|
||||
"""repeat FeedForwardSubNet (num_time_interval - 1) times."""
|
||||
all_one_vec = P.Ones()((P.shape(dw)[0], 1), mstype.float32)
|
||||
y = all_one_vec * self.y_init
|
||||
z = self.matmul(all_one_vec, self.z_init)
|
||||
|
||||
for t in range(0, self.num_time_interval - 1):
|
||||
y = y - self.delta_t * (self.generator(self.time_stamp[t], x[:, :, t], y, z)) + self.sum(z * dw[:, :, t], 1)
|
||||
z = self.subnet[t](x[:, :, t + 1]) / self.dim
|
||||
# terminal time
|
||||
y = y - self.delta_t * self.generator(self.time_stamp[-1], x[:, :, -2], y, z) + self.sum(z * dw[:, :, -1], 1)
|
||||
return y
|
||||
|
||||
|
||||
class FeedForwardSubNet(nn.Cell):
|
||||
"""
|
||||
Subnet to fit the spatial gradients at time t=tn
|
||||
|
||||
Args:
|
||||
dim (int): dimension of the final output
|
||||
train (bool): True for train
|
||||
num_hidden list(int): number of hidden layers
|
||||
"""
|
||||
def __init__(self, dim, num_hiddens):
|
||||
super(FeedForwardSubNet, self).__init__()
|
||||
self.dim = dim
|
||||
self.num_hiddens = num_hiddens
|
||||
bn_layers = [nn.BatchNorm1d(c, momentum=0.99, eps=1e-6, beta_init='normal', gamma_init='uniform')
|
||||
for c in [dim] + num_hiddens + [dim]]
|
||||
self.bns = nn.CellList(bn_layers)
|
||||
dense_layers = [nn.Dense(dim, num_hiddens[0], has_bias=False, activation=None)]
|
||||
dense_layers = dense_layers + [nn.Dense(num_hiddens[i], num_hiddens[i + 1], has_bias=False, activation=None)
|
||||
for i in range(len(num_hiddens) - 1)]
|
||||
# final output should be gradient of size dim
|
||||
dense_layers.append(nn.Dense(num_hiddens[-1], dim, activation=None))
|
||||
self.denses = nn.CellList(dense_layers)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
"""structure: bn -> (dense -> bn -> relu) * len(num_hiddens) -> dense -> bn"""
|
||||
x = self.bns[0](x)
|
||||
hiddens_length = len(self.num_hiddens)
|
||||
for i in range(hiddens_length):
|
||||
x = self.denses[i](x)
|
||||
x = self.bns[i+1](x)
|
||||
x = self.relu(x)
|
||||
x = self.denses[hiddens_length](x)
|
||||
x = self.bns[hiddens_length + 1](x)
|
||||
return x
|
||||
|
||||
|
||||
class WithLossCell(nn.Cell):
|
||||
"""Loss function for DeepBSDE"""
|
||||
def __init__(self, net):
|
||||
super(WithLossCell, self).__init__()
|
||||
self.net = net
|
||||
self.terminal_condition = net.bsde.terminal_condition
|
||||
self.total_time = net.bsde.total_time
|
||||
self.sum = P.ReduceSum()
|
||||
self.delta_clip = 50.0
|
||||
self.selete = P.Select()
|
||||
|
||||
def construct(self, dw, x):
|
||||
y_terminal = self.net(dw, x)
|
||||
delta = y_terminal - self.terminal_condition(self.total_time, x[:, :, -1])
|
||||
# use linear approximation outside the clipped range
|
||||
abs_delta = P.Abs()(delta)
|
||||
loss = self.sum(self.selete(abs_delta < self.delta_clip,
|
||||
P.Square()(delta),
|
||||
2 * self.delta_clip * abs_delta - self.delta_clip * self.delta_clip))
|
||||
return loss
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""DeepBSDE train script"""
|
||||
import os
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore import context, Tensor, Model
|
||||
from mindspore import nn
|
||||
from mindspore.nn.dynamic_lr import piecewise_constant_lr
|
||||
from mindspore.train.callback import TimeMonitor, LossMonitor
|
||||
from src.net import DeepBSDE, WithLossCell
|
||||
from src.config import config
|
||||
from src.equation import get_bsde, create_dataset
|
||||
from src.eval_utils import EvalCallBack
|
||||
|
||||
if __name__ == '__main__':
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||
if not os.path.exists(config.log_dir):
|
||||
os.mkdir(config.log_dir)
|
||||
config.ckpt_path = os.path.join(config.log_dir, "deepbsde_{}_{}.ckpt".format(config.eqn_name, "{}"))
|
||||
bsde = get_bsde(config)
|
||||
dataset = create_dataset(bsde)
|
||||
print('Begin to solve', config.eqn_name)
|
||||
net = DeepBSDE(config, bsde)
|
||||
net_with_loss = WithLossCell(net)
|
||||
config.lr_boundaries.append(config.num_iterations)
|
||||
lr = Tensor(piecewise_constant_lr(config.lr_boundaries, config.lr_values), dtype=mstype.float32)
|
||||
opt = nn.Adam(net.trainable_params(), lr)
|
||||
model = Model(net_with_loss, optimizer=opt)
|
||||
eval_param = {"model": net_with_loss, "valid_data": bsde.sample(config.valid_size)}
|
||||
cb = [LossMonitor(), TimeMonitor(), EvalCallBack(eval_param, config.ckpt_path, config.logging_frequency)]
|
||||
epoch = dataset.get_dataset_size() // config.logging_frequency
|
||||
model.train(epoch, dataset, callbacks=cb, sink_size=config.logging_frequency)
|
Loading…
Reference in New Issue