add DeepBSDE

This commit is contained in:
zhaoting 2021-07-01 10:12:02 +08:00
parent 7dffa5096c
commit 25e3a360d3
13 changed files with 841 additions and 0 deletions


@ -0,0 +1,182 @@
# Contents
- [Contents](#contents)
- [DeepBSDE Description](#deepbsde-description)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
    - [Evaluation Process](#evaluation-process)
- [Model Description](#model-description)
    - [Evaluation Performance](#evaluation-performance)
    - [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [DeepBSDE Description](#contents)
DeepBSDE harnesses the power of deep neural networks to solve a large class of high-dimensional nonlinear PDEs. The class of PDEs it deals with is (nonlinear) parabolic PDEs.

[Paper](https://www.pnas.org/content/115/34/8505): Han J, Jentzen A, E W. Solving high-dimensional partial differential equations using deep learning[J]. Proceedings of the National Academy of Sciences, 2018, 115(34): 8505-8510.
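
For orientation, the time stepping that `src/net.py` performs can be written as the recursion below. This is a sketch read off the code in this commit; the symbols $X$, $Y$, $Z$, $\Delta W$, $f$ and $g$ are notation chosen here for illustration, $N$ stands for `num_time_interval`, and $\Delta t$ is `total_time / num_time_interval`.

$$
\begin{aligned}
X_{t_{n+1}} &= X_{t_n} + \sigma\,\Delta W_n, \\
Y_{t_{n+1}} &= Y_{t_n} - f\big(t_n, X_{t_n}, Y_{t_n}, Z_{t_n}\big)\,\Delta t + Z_{t_n}^{\top}\,\Delta W_n, \qquad n = 0, \dots, N-1 .
\end{aligned}
$$

Here $Y_{t_0}$ and $Z_{t_0}$ are the trainable parameters `y_init` and `z_init`, each later $Z_{t_n}$ is the output of one `FeedForwardSubNet` applied to $X_{t_n}$ and divided by `dim`, and training penalizes the mismatch between $Y_{t_N}$ and the terminal condition $g(X_{t_N})$.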
## [HJB equation](#contents)
The Hamilton-Jacobi-Bellman (HJB) equation arises in dynamic programming, the context in which Richard Bellman first explicitly used the term "curse of dimensionality". Dynamic programming has since become a cornerstone of many areas such as economics, behavioral science, computer science, and even biology, where intelligent decision making is the main issue.
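
The HJB example in this repository is defined by three ingredients in `src/equation.py`: a forward path with diffusion coefficient `sqrt(2)`, a generator `-lambda * sum(z^2)`, and a terminal condition `log((1 + sum(x^2)) / 2)`. The sketch below restates them in plain Python, without MindSpore or NumPy, for illustration only. The function names and the use of the standard `random` module are assumptions made here and are not part of the repository.

```python
import math
import random

LAMBD = 1.0              # lambda passed to HJBLQGenerator in src/equation.py
SIGMA = math.sqrt(2.0)   # diffusion coefficient used by HJBLQ.sample


def generator(z, lambd=LAMBD):
    """f(t, x, y, z) = -lambda * sum(z_i^2); t, x and y are unused for HJBLQ."""
    return -lambd * sum(zi * zi for zi in z)


def terminal_condition(x):
    """g(x) = log((1 + sum(x_i^2)) / 2)."""
    return math.log((1.0 + sum(xi * xi for xi in x)) / 2.0)


def sample_path(dim, num_time_interval, total_time, rng):
    """Simulate X_{n+1} = X_n + sigma * dW_n, starting from the origin."""
    sqrt_dt = math.sqrt(total_time / num_time_interval)
    x = [0.0] * dim
    path = [list(x)]
    for _ in range(num_time_interval):
        dw = [rng.gauss(0.0, 1.0) * sqrt_dt for _ in range(dim)]
        x = [xi + SIGMA * dwi for xi, dwi in zip(x, dw)]
        path.append(list(x))
    return path


# A small 3-dimensional path; the default config uses dim=100.
path = sample_path(dim=3, num_time_interval=20, total_time=1.0, rng=random.Random(0))
print(len(path), terminal_condition(path[-1]), generator([1.0, 2.0]))
```

A path has `num_time_interval + 1` points because it includes the starting point, which matches the shape of `x_sample` in `HJBLQ.sample`.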
# [Environment Requirements](#contents)
- Hardware (GPU)
    - Prepare a hardware environment with a GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
```shell
# Run the training example
export CUDA_VISIBLE_DEVICES=0
python train.py --config_path=./config/HJBLQ_config.yaml
# OR (DEVICE_ID is optional and defaults to 0)
bash ./scripts/run_train.sh [CONFIG_YAML] [DEVICE_ID]
# Run the evaluation example
python eval.py --config_path=./config/HJBLQ_config.yaml
# OR (DEVICE_ID is optional and defaults to 0)
bash ./scripts/run_eval.sh [CONFIG_YAML] [DEVICE_ID]
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```text
.
├── config
│   └── HJBLQ_config.yaml     # default config for HJB equation.
├── scripts
│   ├── run_eval.sh           # launch evaluation in the background.
│   └── run_train.sh          # launch training in the background.
├── src
│   ├── config.py             # config parse script.
│   ├── equation.py           # equation definition and dataset helper.
│   ├── eval_utils.py         # evaluation callback and evaluation utils.
│   └── net.py                # DeepBSDE network structure.
├── eval.py                   # evaluation API entry.
├── export.py                 # export models API entry.
├── README_CN.md
├── README.md
├── requirements.txt          # requirements of third party package.
└── train.py                  # train API entry.
## [Script Parameters](#contents)
Parameters for both training and evaluation can be set in `CONFIG_YAML`.
- Config for HJB
```yaml
# eqn config
eqn_name: "HJBLQ"           # Equation function name.
total_time: 1.0             # The total time of equation function.
dim: 100                    # Dimension of the PDE state variable x.
num_time_interval: 20       # Number of time intervals.
# net config
y_init_range: [0, 1]        # The y_init random initialization range.
num_hiddens: [110, 110]     # Number of units in each hidden layer.
lr_values: [0.01, 0.01]     # lr_values of piecewise_constant_lr.
lr_boundaries: [1000]       # lr_boundaries of piecewise_constant_lr.
num_iterations: 2000        # Number of training iterations.
batch_size: 64              # batch_size when training.
valid_size: 256             # batch_size when evaluation.
logging_frequency: 100      # logging and evaluation callback frequency.
# other config
device_target: "GPU"        # Device on which the code runs. The only supported value is GPU.
log_dir: "./logs"           # The path of log saving.
file_format: "MINDIR"       # Export model type.
```
For more configuration details, please refer to the yaml file `./config/HJBLQ_config.yaml`.
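
To see how `lr_boundaries`, `lr_values` and `num_iterations` combine, note that `train.py` appends `num_iterations` to `lr_boundaries` before building the schedule. The pure-Python function below is a stand-in written for illustration, assuming one learning rate per step and half-open intervals between milestones. It is not MindSpore's `piecewise_constant_lr` itself.

```python
def piecewise_constant(milestones, values):
    """Return one learning rate per step: values[j] is used up to milestones[j]."""
    if len(milestones) != len(values):
        raise ValueError("milestones and values must have the same length")
    schedule = []
    start = 0
    for end, value in zip(milestones, values):
        schedule.extend([value] * (end - start))
        start = end
    return schedule


# Values from config/HJBLQ_config.yaml
lr_boundaries = [1000]
lr_values = [0.01, 0.01]
num_iterations = 2000

# train.py appends num_iterations, so the milestones become [1000, 2000]
milestones = lr_boundaries + [num_iterations]
schedule = piecewise_constant(milestones, lr_values)
print(len(schedule), schedule[0], schedule[-1])
```

With the default config both segments use 0.01, so the schedule is effectively constant. Changing the second entry of `lr_values` would lower the rate after step 1000.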
## [Training Process](#contents)
- Running on GPU
```bash
python train.py --config_path=./config/HJBLQ_config.yaml > train.log 2>&1 &
```
- The command above runs in the background; you can follow its progress in the file `train.log`.
The training log reports the loss as follows:
```log
epoch: 1 step: 100, loss is 245.3738
epoch time: 26883.370 ms, per step time: 268.834 ms
total step: 100, eval loss: 1179.300, Y0: 1.400, elapsed time: 34
epoch: 2 step: 100, loss is 149.6593
epoch time: 3184.401 ms, per step time: 32.877 ms
total step: 200, eval loss: 659.457, Y0: 1.693, elapsed time: 37
...
```
After training, you'll get the last checkpoint file under the folder `log_dir` in config.
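
To track convergence, the evaluation lines in `train.log` can be pulled out with a small script. The sketch below assumes the line format printed by `EvalCallBack` in `src/eval_utils.py`. The helper name `parse_eval_lines` and the dictionary keys are illustrative choices, and values such as `nan` or scientific notation are not handled.

```python
import re

# Matches lines such as:
# total step:  100, eval loss: 1179.300, Y0: 1.400, elapsed time:  34
EVAL_LINE = re.compile(
    r"total step:\s*(\d+), eval loss:\s*([-\d.]+), Y0:\s*([-\d.]+), elapsed time:\s*(\d+)"
)


def parse_eval_lines(text):
    """Extract step, eval loss, Y0 and elapsed seconds from log text."""
    records = []
    for match in EVAL_LINE.finditer(text):
        step, loss, y0, elapsed = match.groups()
        records.append({
            "step": int(step),
            "eval_loss": float(loss),
            "y0": float(y0),
            "elapsed_s": int(elapsed),
        })
    return records


sample = """\
epoch: 1 step: 100, loss is 245.3738
total step:  100, eval loss: 1179.300, Y0: 1.400, elapsed time:  34
total step:  200, eval loss: 659.457, Y0: 1.693, elapsed time:  37
"""
records = parse_eval_lines(sample)
for record in records:
    print(record)
```

For a real run, replace `sample` with the contents of `train.log`, for example `open("train.log").read()`.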
## [Evaluation Process](#contents)
- Evaluation on GPU
Before running the command below, check the checkpoint path used for evaluation, for example `./logs/deepbsde_HJBLQ_end.ckpt`.
```bash
python eval.py --config_path=./config/HJBLQ_config.yaml > eval.log 2>&1 &
```
The command above runs in the background. You can view the results in the file `eval.log`. The evaluation result looks as follows:
```log
eval loss: 5.146923065185527, Y0: 4.59813117980957
```
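
The `eval loss` value is produced by `WithLossCell` in `src/net.py`: the squared difference between the predicted and exact terminal values, switched to a linear function once the absolute difference reaches `delta_clip = 50`, and summed over the batch. The following plain-Python restatement is a sketch for illustration. The function name `clipped_loss` is an assumption made here.

```python
DELTA_CLIP = 50.0  # same constant as WithLossCell.delta_clip


def clipped_loss(deltas, delta_clip=DELTA_CLIP):
    """Sum of delta^2 inside the clip range, linear extension outside it."""
    total = 0.0
    for delta in deltas:
        abs_delta = abs(delta)
        if abs_delta < delta_clip:
            total += delta * delta
        else:
            # linear approximation outside the clipped range
            total += 2.0 * delta_clip * abs_delta - delta_clip ** 2
    return total


small = clipped_loss([1.0, -2.0])   # 1 + 4
large = clipped_loss([60.0])        # 2 * 50 * 60 - 50^2
print(small, large)
```

Because the loss is a sum rather than a mean, its magnitude scales with the number of samples, so values computed with `batch_size` and `valid_size` are not directly comparable.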
# [Model Description](#contents)
## [Evaluation Performance](#contents)
| Parameters | GPU |
| -------------------------- | ------------------------------------------------------------ |
| Model Version | DeepBSDE |
| Resource | NV SMX2 V100-32G |
| Uploaded Date              | 7/5/2021 (month/day/year)                                    |
| MindSpore Version          | 1.2.0                                                        |
| Training Parameters        | step=2000, see `./config/HJBLQ_config.yaml` for details      |
| Optimizer                  | Adam                                                         |
| Loss | 2.11 |
| Speed | 32ms/step |
| Total time | 3 min |
| Parameters | 650K |
| Checkpoint for Fine tuning | 7.8M (.ckpt file) |
## [Inference Performance](#contents)
| Parameters | GPU |
| ----------------- | -------------------------------------------- |
| Model Version | DeepBSDE |
| Resource | NV SMX2 V100-32G |
| Uploaded Date     | 7/5/2021 (month/day/year)                    |
| MindSpore Version | 1.2.0                                        |
| Outputs           | eval loss & Y0                               |
| Y0                | Y0: 4.59                                     |
# [Description of Random Situation](#contents)
Random sampling is used in `equation.py`; set a random seed there to make the results reproducible.
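
The effect of fixing a seed can be illustrated with the standard library alone. The sketch below is not the repository's sampling code: it uses Python's `random` module and a hypothetical helper `draw_increments`. `equation.py` draws its samples through SciPy and NumPy, so the analogous step there would presumably be a call such as `numpy.random.seed(...)` before sampling. That is an assumption and is not shown in this commit.

```python
import random


def draw_increments(seed, count=5, sqrt_dt=0.05 ** 0.5):
    """Draw `count` scaled Gaussian increments from a generator seeded with `seed`."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) * sqrt_dt for _ in range(count)]


first = draw_increments(seed=1)
second = draw_increments(seed=1)
other = draw_increments(seed=2)

# Same seed gives the same increments; a different seed gives different ones.
print(first == second, first == other)
```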
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).


@ -0,0 +1,41 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
# eqn config
eqn_name: "HJBLQ"
total_time: 1.0
dim: 100
num_time_interval: 20
# net config
y_init_range: [0, 1]
num_hiddens: [110, 110]
lr_values: [0.01, 0.01]
lr_boundaries: [1000]
num_iterations: 2000
batch_size: 64
valid_size: 256
logging_frequency: 100
# other config
device_target: "GPU"
log_dir: "./logs"
file_format: "MINDIR"
---
# Help description for each configuration
eqn_name: "Equation function name."
total_time: "The total time of equation function."
dim: "Dimension of the PDE state variable x."
num_time_interval: "Number of time intervals."
y_init_range: "The y_init random initialization range."
num_hiddens: "Number of units in each hidden layer."
lr_values: "lr_values of piecewise_constant_lr."
lr_boundaries: "lr_boundaries of piecewise_constant_lr."
num_iterations: "Iterations numbers."
batch_size: "batch_size when training."
valid_size: "batch_size when evaluation."
logging_frequency: "logging and evaluation callback frequency."
device_target: "Device on which the code runs. The only supported value is GPU."
log_dir: "The path of log saving."
file_format: "Export model type."


@ -0,0 +1,33 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DeepBSDE evaluation script"""
import os
from mindspore import context, load_checkpoint
from src.net import DeepBSDE, WithLossCell
from src.config import config
from src.equation import get_bsde
from src.eval_utils import apply_eval


if __name__ == '__main__':
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
    config.ckpt_path = os.path.join(config.log_dir, "deepbsde_{}_end.ckpt".format(config.eqn_name))
    bsde = get_bsde(config)
    print('Begin to solve', config.eqn_name)
    net = DeepBSDE(config, bsde)
    net_with_loss = WithLossCell(net)
    load_checkpoint(config.ckpt_path, net=net_with_loss)
    eval_param = {"model": net_with_loss, "valid_data": bsde.sample(config.valid_size)}
    loss, y_init = apply_eval(eval_param)
    print("eval loss: {}, Y0: {}".format(loss, y_init))


@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DeepBSDE export model script"""
import os
from mindspore import context, load_checkpoint, export, Tensor
from src.net import DeepBSDE
from src.config import config
from src.equation import get_bsde


if __name__ == '__main__':
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
    config.ckpt_path = os.path.join(config.log_dir, "deepbsde_{}_end.ckpt".format(config.eqn_name))
    bsde = get_bsde(config)
    print('Begin to solve', config.eqn_name)
    net = DeepBSDE(config, bsde)
    load_checkpoint(config.ckpt_path, net=net)
    dw, x = bsde.sample(config.valid_size)
    export(net, Tensor(dw), Tensor(x), file_name="deepbsde_{}".format(config.eqn_name), file_format=config.file_format)


@ -0,0 +1,2 @@
scipy >= 1.5.2
PyYAML


@ -0,0 +1,45 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# Resolve a possibly relative path to an absolute one.
get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        realpath -m "$PWD/$1"
    fi
}

if [ $# != 1 ] && [ $# != 2 ]
then
    echo "=============================================================================================================="
    echo "Please run the script as: "
    echo "bash ./scripts/run_eval.sh [CONFIG_YAML] [DEVICE_ID](optional, default is 0)"
    echo "for example: bash ./scripts/run_eval.sh ./config/HJBLQ_config.yaml 0"
    echo "=============================================================================================================="
    exit 1
fi

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export DEVICE_ID=0
if [ $# == 2 ];
then
    export DEVICE_ID=$2
fi

config_yaml=$(get_real_path "$1")
# Run evaluation in the background; output goes to eval.log.
nohup python "${PROJECT_DIR}/../eval.py" --config_path="$config_yaml" > eval.log 2>&1 &


@ -0,0 +1,45 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# Resolve a possibly relative path to an absolute one.
get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        realpath -m "$PWD/$1"
    fi
}

if [ $# != 1 ] && [ $# != 2 ]
then
    echo "=============================================================================================================="
    echo "Please run the script as: "
    echo "bash ./scripts/run_train.sh [CONFIG_YAML] [DEVICE_ID](optional, default is 0)"
    echo "for example: bash ./scripts/run_train.sh ./config/HJBLQ_config.yaml 0"
    echo "=============================================================================================================="
    exit 1
fi

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export DEVICE_ID=0
if [ $# == 2 ];
then
    export DEVICE_ID=$2
fi

config_yaml=$(get_real_path "$1")
# Run training in the background; output goes to train.log.
nohup python "${PROJECT_DIR}/../train.py" --config_path="$config_yaml" > train.log 2>&1 &


@ -0,0 +1,127 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml


class Config:
    """
    Configuration namespace. Convert dictionary to members.
    """
    def __init__(self, cfg_dict):
        for k, v in cfg_dict.items():
            if isinstance(v, (list, tuple)):
                setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
            else:
                setattr(self, k, Config(v) if isinstance(v, dict) else v)

    def __str__(self):
        return pformat(self.__dict__)

    def __repr__(self):
        return self.__str__()


def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
    """
    Parse command line arguments to the configuration according to the default yaml.

    Args:
        parser: Parent parser.
        cfg: Base configuration.
        helper: Helper description.
        choices: Allowed values for each option.
        cfg_path: Path to the default yaml config.
    """
    parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
                                     parents=[parser])
    helper = {} if helper is None else helper
    choices = {} if choices is None else choices
    for item in cfg:
        # list and dict entries cannot be overridden from the command line
        if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
            help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
            choice = choices[item] if item in choices else None
            if isinstance(cfg[item], bool):
                parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
                                    help=help_description)
            else:
                parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
                                    help=help_description)
    args = parser.parse_args()
    return args


def parse_yaml(yaml_path):
    """
    Parse the yaml config file.

    Args:
        yaml_path: Path to the yaml config.
    """
    with open(yaml_path, 'r') as fin:
        try:
            cfgs = list(yaml.load_all(fin.read(), Loader=yaml.FullLoader))
        except yaml.YAMLError as err:
            # catch only yaml errors so the document-count error below is not masked
            raise ValueError("Failed to parse yaml") from err
    if len(cfgs) == 1:
        cfg_helper = {}
        cfg = cfgs[0]
        cfg_choices = {}
    elif len(cfgs) == 2:
        cfg, cfg_helper = cfgs
        cfg_choices = {}
    elif len(cfgs) == 3:
        cfg, cfg_helper, cfg_choices = cfgs
    else:
        raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
    return cfg, cfg_helper, cfg_choices


def merge(args, cfg):
    """
    Merge the base config from yaml file and command line arguments.

    Args:
        args: Command line arguments.
        cfg: Base configuration.
    """
    args_var = vars(args)
    for item in args_var:
        cfg[item] = args_var[item]
    return cfg


def get_config():
    """
    Get Config according to the yaml file and cli arguments.
    """
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../config/HJBLQ_config.yaml"),
                        help="Config file path")
    path_args, _ = parser.parse_known_args()
    default, helper, choices = parse_yaml(path_args.config_path)
    pprint(default)
    args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
    final_config = merge(args, default)
    return Config(final_config)


config = get_config()


@ -0,0 +1,104 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""equations for different PDE function."""
import numpy as np
from scipy.stats import multivariate_normal as normal
from mindspore import ops as P
from mindspore import nn
import mindspore.dataset as ds


class Equation:
    """Base class for defining PDE related function."""
    def __init__(self, cfg):
        self.dim = cfg.dim
        self.total_time = cfg.total_time
        self.steps = cfg.num_iterations
        self.num_time_interval = cfg.num_time_interval
        self.delta_t = self.total_time / self.num_time_interval
        self.sqrt_delta_t = np.sqrt(self.delta_t)
        self.y_init = None
        self.num_sample = cfg.batch_size
        self.generator = P.Identity()
        self.terminal_condition = P.Identity()

    def sample(self, num_sample):
        """Sample forward SDE."""
        raise NotImplementedError

    def __getitem__(self, index):
        # every item is a freshly sampled batch; the index is ignored
        return self.sample(self.num_sample)

    @property
    def column_names(self):
        return ["dw", "x"]

    def __len__(self):
        return self.steps


class HJBLQ(Equation):
    """HJB equation in PNAS paper doi.org/10.1073/pnas.1718942115"""
    def __init__(self, cfg):
        super(HJBLQ, self).__init__(cfg)
        self.x_init = np.zeros(self.dim)
        self.sigma = np.sqrt(2.0)
        self.generator = HJBLQGenerator(1.0)
        self.terminal_condition = HJBLQTerminalCondition()

    def sample(self, num_sample):
        # draw standard normal samples and scale them to Brownian increments
        dw_sample = normal.rvs(size=[num_sample,
                                     self.dim,
                                     self.num_time_interval]) * self.sqrt_delta_t  # num_sample, dim, num_time_interval
        x_sample = np.zeros([num_sample, self.dim, self.num_time_interval + 1])
        for i in range(self.num_time_interval):
            x_sample[:, :, i + 1] = x_sample[:, :, i] + self.sigma * dw_sample[:, :, i]
        return dw_sample.astype(np.float32), x_sample.astype(np.float32)


class HJBLQGenerator(nn.Cell):
    """Generator function for HJBLQ"""
    def __init__(self, lambd):
        super(HJBLQGenerator, self).__init__()
        self.lambd = lambd
        self.sum = P.ReduceSum(keep_dims=True)
        self.square = P.Square()

    def construct(self, t, x, y, z):
        res = -self.lambd * self.sum(self.square(z), 1)
        return res


class HJBLQTerminalCondition(nn.Cell):
    """Terminal condition for HJBLQ"""
    def __init__(self):
        super(HJBLQTerminalCondition, self).__init__()
        self.sum = P.ReduceSum(keep_dims=True)
        self.square = P.Square()

    def construct(self, t, x):
        res = P.log((1 + self.sum(self.square(x), 1)) / 2)
        return res


def get_bsde(cfg):
    """Return the equation object selected by cfg.eqn_name (case-insensitive)."""
    bsde_dict = {"HJBLQ": HJBLQ(cfg)}
    return bsde_dict[cfg.eqn_name.upper()]


def create_dataset(bsde):
    """Get generator dataset when training."""
    dataset = ds.GeneratorDataset(bsde, bsde.column_names)
    return dataset


@ -0,0 +1,69 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation callback when training"""
import time
from mindspore import Tensor, save_checkpoint
from mindspore.train.callback import Callback


class EvalCallBack(Callback):
    """
    Evaluation callback when training.

    Args:
        eval_param_dict (dict): evaluation parameters' configure dict.
        ckpt_path (str): save checkpoint path format, eg: "./logs/deepbsde_hjb_{}.ckpt".
        interval (int): run evaluation interval, default is 1.

    Returns:
        None

    Examples:
        >>> EvalCallBack(eval_param_dict, "./logs/deepbsde_hjb_{}.ckpt", interval=100)
    """
    def __init__(self, eval_param_dict, ckpt_path, interval=1):
        super(EvalCallBack, self).__init__()
        self.eval_param_dict = eval_param_dict
        self.eval_function = apply_eval
        if interval < 1:
            raise ValueError("interval should >= 1.")
        self.interval = interval
        self.best_res = 0
        self.best_epoch = 0
        self.ckpt_path = ckpt_path
        self.start_time = time.time()

    def epoch_end(self, run_context):
        """Callback when epoch end."""
        cb_params = run_context.original_args()
        cur_step = cb_params.cur_step_num
        if cur_step % self.interval == 0:
            loss, y_init = self.eval_function(self.eval_param_dict)
            elapsed_time = time.time() - self.start_time
            print("total step: {:4d}, eval loss: {:5.3f}, Y0: {:5.3f}, elapsed time: {:3.0f}".format(
                cur_step, loss, y_init, elapsed_time))

    def end(self, run_context):
        """Save the final checkpoint when training ends."""
        cb_params = run_context.original_args()
        save_checkpoint(cb_params.train_network, self.ckpt_path.format("end"))


def apply_eval(eval_param):
    """Run the loss network on the validation data and return (loss, Y0)."""
    eval_model = eval_param["model"]
    dw, x = eval_param["valid_data"]
    eval_model.set_train(False)
    loss = eval_model(Tensor(dw), Tensor(x)).asnumpy()
    y_init = eval_model.net.y_init.asnumpy()[0]
    return loss, y_init


@ -0,0 +1,119 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the network structure of DeepBSDE"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import nn
from mindspore import ops as P
from mindspore import Tensor, Parameter


class DeepBSDE(nn.Cell):
    """
    The network structure of DeepBSDE.

    Args:
        cfg: configure settings.
        bsde (Equation): equation definition, see src/equation.py.
    """
    def __init__(self, cfg, bsde):
        super(DeepBSDE, self).__init__()
        self.bsde = bsde
        self.delta_t = bsde.delta_t
        self.num_time_interval = bsde.num_time_interval
        self.dim = bsde.dim
        self.time_stamp = Tensor(np.arange(0, cfg.num_time_interval) * bsde.delta_t)
        self.y_init = Parameter(np.random.uniform(low=cfg.y_init_range[0],
                                                  high=cfg.y_init_range[1],
                                                  size=[1]).astype(np.float32))
        self.z_init = Parameter(np.random.uniform(low=-0.1, high=0.1, size=[1, cfg.dim]).astype(np.float32))
        self.subnet = nn.CellList([FeedForwardSubNet(cfg.dim, cfg.num_hiddens)
                                   for _ in range(bsde.num_time_interval-1)])
        self.generator = bsde.generator
        self.matmul = P.MatMul()
        self.sum = P.ReduceSum(keep_dims=True)

    def construct(self, dw, x):
        """repeat FeedForwardSubNet (num_time_interval - 1) times."""
        all_one_vec = P.Ones()((P.shape(dw)[0], 1), mstype.float32)
        y = all_one_vec * self.y_init
        z = self.matmul(all_one_vec, self.z_init)
        for t in range(0, self.num_time_interval - 1):
            y = y - self.delta_t * (self.generator(self.time_stamp[t], x[:, :, t], y, z)) + self.sum(z * dw[:, :, t], 1)
            z = self.subnet[t](x[:, :, t + 1]) / self.dim
        # terminal time
        y = y - self.delta_t * self.generator(self.time_stamp[-1], x[:, :, -2], y, z) + self.sum(z * dw[:, :, -1], 1)
        return y


class FeedForwardSubNet(nn.Cell):
    """
    Subnet to fit the spatial gradients at time t=tn

    Args:
        dim (int): dimension of the input and of the final output.
        num_hiddens (list[int]): number of units in each hidden layer.
    """
    def __init__(self, dim, num_hiddens):
        super(FeedForwardSubNet, self).__init__()
        self.dim = dim
        self.num_hiddens = num_hiddens
        bn_layers = [nn.BatchNorm1d(c, momentum=0.99, eps=1e-6, beta_init='normal', gamma_init='uniform')
                     for c in [dim] + num_hiddens + [dim]]
        self.bns = nn.CellList(bn_layers)
        dense_layers = [nn.Dense(dim, num_hiddens[0], has_bias=False, activation=None)]
        dense_layers = dense_layers + [nn.Dense(num_hiddens[i], num_hiddens[i + 1], has_bias=False, activation=None)
                                       for i in range(len(num_hiddens) - 1)]
        # final output should be gradient of size dim
        dense_layers.append(nn.Dense(num_hiddens[-1], dim, activation=None))
        self.denses = nn.CellList(dense_layers)
        self.relu = nn.ReLU()

    def construct(self, x):
        """structure: bn -> (dense -> bn -> relu) * len(num_hiddens) -> dense -> bn"""
        x = self.bns[0](x)
        hiddens_length = len(self.num_hiddens)
        for i in range(hiddens_length):
            x = self.denses[i](x)
            x = self.bns[i+1](x)
            x = self.relu(x)
        x = self.denses[hiddens_length](x)
        x = self.bns[hiddens_length + 1](x)
        return x


class WithLossCell(nn.Cell):
    """Loss function for DeepBSDE"""
    def __init__(self, net):
        super(WithLossCell, self).__init__()
        self.net = net
        self.terminal_condition = net.bsde.terminal_condition
        self.total_time = net.bsde.total_time
        self.sum = P.ReduceSum()
        self.delta_clip = 50.0
        self.select = P.Select()

    def construct(self, dw, x):
        y_terminal = self.net(dw, x)
        delta = y_terminal - self.terminal_condition(self.total_time, x[:, :, -1])
        # use linear approximation outside the clipped range
        abs_delta = P.Abs()(delta)
        loss = self.sum(self.select(abs_delta < self.delta_clip,
                                    P.Square()(delta),
                                    2 * self.delta_clip * abs_delta - self.delta_clip * self.delta_clip))
        return loss


@ -0,0 +1,44 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DeepBSDE train script"""
import os
from mindspore import dtype as mstype
from mindspore import context, Tensor, Model
from mindspore import nn
from mindspore.nn.dynamic_lr import piecewise_constant_lr
from mindspore.train.callback import TimeMonitor, LossMonitor
from src.net import DeepBSDE, WithLossCell
from src.config import config
from src.equation import get_bsde, create_dataset
from src.eval_utils import EvalCallBack


if __name__ == '__main__':
    context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
    if not os.path.exists(config.log_dir):
        os.mkdir(config.log_dir)
    config.ckpt_path = os.path.join(config.log_dir, "deepbsde_{}_{}.ckpt".format(config.eqn_name, "{}"))
    bsde = get_bsde(config)
    dataset = create_dataset(bsde)
    print('Begin to solve', config.eqn_name)
    net = DeepBSDE(config, bsde)
    net_with_loss = WithLossCell(net)
    # the last milestone of the learning rate schedule is the total number of iterations
    config.lr_boundaries.append(config.num_iterations)
    lr = Tensor(piecewise_constant_lr(config.lr_boundaries, config.lr_values), dtype=mstype.float32)
    opt = nn.Adam(net.trainable_params(), lr)
    model = Model(net_with_loss, optimizer=opt)
    eval_param = {"model": net_with_loss, "valid_data": bsde.sample(config.valid_size)}
    cb = [LossMonitor(), TimeMonitor(), EvalCallBack(eval_param, config.ckpt_path, config.logging_frequency)]
    epoch = dataset.get_dataset_size() // config.logging_frequency
    model.train(epoch, dataset, callbacks=cb, sink_size=config.logging_frequency)