PINNs (Schrodinger)

This commit is contained in:
yuyiyang_3418 2021-05-18 19:20:03 +08:00
parent 84837b0bc9
commit bbeaebaf0b
15 changed files with 1183 additions and 1 deletions

View File

@ -0,0 +1,218 @@
# Contents
[查看中文](./README_CN.md)
- [PINNs Description](#PINNs-Description)
- [Model Architecture](#model-architecture)
- [Schrodinger equation](#Schrodinger-equation)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Evaluation of Schrodinger equation scenario](#Evaluation-of-Schrodinger-equation-scenario)
- [Inference Performance](#evaluation-performance)
- [Inference of Schrodinger equation scenario](#Inference-of-Schrodinger-equation-scenario)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [PINNs Description](#contents)
PINNs (Physics Information Neural Networks) is a neural network proposed in 2019. PINNs network provides a new approach for solving partial differential equations with neural network. Partial differential equations are often used in the modeling of physical, biological and engineering systems. The characteristics of such systems have significantly difference from most problems in machine learning: (1) the cost of data acquisition is high, and the amount of data is usually small;2) a large amount of priori knowledge, such as previous research result like physical laws, are hard to be utilized by machine learning systems.
In PINNs, firstly the prior knowledge in the form of partial differential equation is introduced as the regularization term of the network through proper construction of the Pinns network. Then, by utilizing the prior knowledge in PINNs, the network can train very good results with very little data.
[paper](https://www.sciencedirect.com/science/article/pii/S0021999118307125)Raissi, Maziar, Paris Perdikaris, and George E. Karniadakis. "Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations."*Journal of Computational Physics*. 2019 (378): 686-707.
# [Model Architecture](#contents)
Pinns is a new framework of constructing neural network for solving partial differential equations. The specific model structure will change according to the partial differential equations. The network structure of each application scenario of PINNs implemented in MindSpore is as follows:
## [Schrodinger equation](#Contents)
The PINNs of the Schrodinger equation can be divided into two parts. First, a neural network composed of five fully connected layers is used to fit the wave function to be solved (i.e., the solution of the Schrodinger equation in the quantum mechanics system described by the data set). The neural network has two outputs, which represent the real part and the imaginary part of the wave function respectively. Then, the two outputs are followed by some derivative operations. The Schrodinger equation can be expressed by properly combining these derivative results, and act as a constraint term of the neural network. The outputs of the whole network are the real part, imaginary part and some related partial derivatives of the wave function.
# [Dataset](#contents)
Note that you can run the scripts based on the dataset once you have downloaded the dataset from the corresponding link to the data storage path (default path is '/PINNs/Data/') . In the following sections, we will introduce how to run the scripts using the related dataset below.
Dataset used: [NLS](https://github.com/maziarraissi/PINNs/tree/master/main/Data), can refer to [paper](https://www.sciencedirect.com/science/article/pii/S0021999118307125)
- Dataset size546KB51456 points sampled from the wave function of a one-dimensional quantum mechanics system with periodic boundary conditions.
- Train150 data points
- TestAll 51456 data points of the dataset.
- Data formatmat files
- NoteThis dataset is used in the Schrodinger equation scenario. Data will be processed in src/Schrodinger/dataset.py
# [Features](#contents)
## [Mixed Precision](#Contents)
The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching reduce precision.
# [Environment Requirements](#contents)
- HardwareGPU
- Prepare hardware environment with GPU processor.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
- Schrodinger equation scenario running on GPU
```shell
# Running training example
export CUDA_VISIBLE_DEVICES=0
python train.py --scenario=Schrodinger --datapath=./Data/mat > train.log
OR
bash /scripts/run_standalone_Schrodinger_train.sh Schrodinger
# Running evaluation example
python eval.py [CHECKPOINT_PATH] --scenario=Schrodinger ----datapath=[DATASET_PATH] > eval.log
OR
bash /scriptsrun_standalone_Schrodinger_eval.sh [CHECKPOINT_PATH] [DATASET_PATH]
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```text
├── model_zoo
├── README.md // descriptions about all the models
├── PINNs
├── README.md // descriptions about PINNs
├── scripts
│ ├──run_standalone_Schrodinger_train.sh // shell script for Schrodinger equation scenario training on GPU
| ├──run_standalone_Schrodinger_eval.sh // shell script for Schrodinger equation scenario evaluation on GPU
├── src
| ├──Schrodinger //Schrodinger equation scenario
│ | ├──dataset.py // creating dataset
│ | ├──net.py // PINNs (Schrodinger) architecture
│ ├──config.py // parameter configuration
├── train.py // training script (Schrodinger)
├── eval.py // evaluation script (Schrodinger)
├── export.py // export checkpoint files into mindir (Schrodinger) ├── ├── requirements // additional packages required to run PINNs networks
```
## [Script Parameters](#contents)
Parameters for both training and evaluation can be set in config.py
- config for Schrodinger equation scenario
```python
'epoch':50000 # number of epochs in training
'lr':0.0001 # learning rate
'N0':50 # number of sampling points of the training set at the initial condition. For the NLS dataset, 0<N0<=256
'Nb':50 # number of sampling points of the training set at the boundary condition. For the NLS dataset, 0<N0<=201
'Nf':20000 # number of collocations points used to calculate the constraint of Schrodinger equation in training. For the NLS dataset, 0<Nf<=51456
'num_neuron':100 # number of neurons in fully connected hidden layer of PINNs network for Schrodinger equation
'seed':2 # random seed
'path':'./Data/NLS.mat' # data set storage path
'ck_path':'./ckpoints/' # path to save checkpoint files (.ckpt)
```
For more configuration details, please refer the script `config.py`.
## [Training Process](#contents)
- Running Schrodinger equation scenario on GPU
```bash
python train.py --scenario=Schrodinger --datapath=[DATASET_PATH] > train.log 2>&1 &
```
- The python command above will run in the background, you can view the results through the file `train.log`
The loss value can be achieved as follows:
```bash
# grep "loss is " train.log
epoch: 1 step: 1, loss is 1.3523688
epoch time: 7519.499 ms, per step time: 7519.499 ms
epcoh: 2 step: 1, loss is 1.2859955
epoch time: 429.470 ms
...
```
After training, you'll get some checkpoint files under the folder `./ckpoints/` by default.
## [Evaluation Process](#contents)
- evaluation of Schrodinger equation scenario when running on GPU
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., “./ckpt/checkpoint_PINNs_Schrodinger-50000_1.ckpt”。
```bash
python eval.py [CHECKPOINT_PATH] --scenario=Schrodinger ----datapath=[DATASET_PATH] > eval.log
```
The above python command will run in the background. You can view the results through the file "eval.log". The error of evaluation is as follows:
```bash
# grep "accuracy:" eval.log
evaluation error is: 0.01207
```
# [Model Description](#contents)
## [Performance](#contents)
### [Evaluation Performance](#contents)
#### [Evaluation of Schrodinger equation scenario](#contents)
| Parameters | GPU |
| -------------------------- | ------------------------------------------------------------ |
| Model Version | PINNs (Schrodinger) |
| Resource | NV Tesla V100-32G |
| uploaded Date | 5/20/2021 (month/day/year) |
| MindSpore Version | 1.2.0 |
| Dataset | NLS |
| Training Parameters | epoch=50000, lr=0.0001. see src/config.py for details |
| Optimizer | Adam |
| Loss Function | src/Schrodinger/loss.py |
| outputs | the wave function (real part and imaginary part)first order derivative of the wave function to the coordinates (real part and imaginary part)the fitting of the Schrodinger equation (real part and imaginary part) |
| Loss | 0.00009928 |
| Speed | 456ms/step |
| Total time | 6.3344 hours |
| Parameters | 32K |
| Checkpoint for Fine tuning | 363K (.ckpt file) |
### [Inference Performance](#contents)
#### [Inference of Schrodinger equation scenario](#contents)
| Parameters | GPU |
| ----------------- | -------------------------------------------- |
| Model Version | PINNs (Schrodinger) |
| Resource | NV Tesla V100-32G |
| uploaded Date | 5/20/2021 (month/day/year) |
| MindSpore Version | 1.2.0 |
| Dataset | NLS |
| outputs | real part and imaginary of the wave function |
| mean square error | 0.01323 |
# [Description of Random Situation](#contents)
We use random seed in train.pywhich can be reset in src/config.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,219 @@
# 目录
[View English](./README.md)
- [目录](#目录)
- [PINNs描述](#PINNs描述)
- [模型架构](#模型架构)
- [Schrodinger方程](#Schrodinger方程)
- [数据集](#数据集)
- [特性](#特性)
- [混合精度](#混合精度)
- [环境要求](#环境要求)
- [快速入门](#快速入门)
- [脚本说明](#脚本说明)
- [脚本及样例代码](#脚本及样例代码)
- [脚本参数](#脚本参数)
- [训练过程](#训练过程)
- [评估过程](#评估过程)
- [评估](#评估)
- [模型描述](#模型描述)
- [性能](#性能)
- [评估性能](#评估性能)
- [Schrodinger方程场景评估](#Schrodinger方程场景评估)
- [推理性能](#推理性能)
- [Schrodinger方程场景推理](#Schrodinger方程场景推理)
- [随机情况说明](#随机情况说明)
- [ModelZoo主页](#modelzoo主页)
# [PINNs描述](#目录)
PINNs (Physics-informed neural networks)是2019年提出的神经网络。PINNs网络提供了一种全新的用神经网络求解偏微分方程的思路。对现实的物理、生物、工程等系统建模时常常会用到偏微分方程。而此类问题的特征与机器学习中遇到的大多数问题有两点显著不同(1)获取数据的成本较高,数据量通常较小;(2)存在大量前人对于此类问题的研究成果作为先验知识而无法被机器学习系统利用例如各种物理定律等。PINNs网络首先通过适当的构造将偏微分方程形式的先验知识作为网络的正则化约束引入进而通过利用这些先验知识强大的约束作用使得网络能够用很少的数据就训练出很好的结果。PINNs网络在量子力学等场景中经过了成功的验证能够用很少的数据成功训练网络并对相应的物理系统进行建模。
[论文](https://www.sciencedirect.com/science/article/pii/S0021999118307125)Raissi, Maziar, Paris Perdikaris, and George E. Karniadakis. "Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations."*Journal of Computational Physics*. 2019 (378): 686-707.
# [模型架构](#目录)
PINNs是针对偏微分方程问题构造神经网络的思路具体的模型结构会根据所要求解的偏微分方程而有相应的变化。在MindSpore中实现的PINNs各应用场景的网络结构如下
## [Schrodinger方程](#目录)
针对Schrodinger方程的PINNs分为两部分首先是一个由5个全连接层组成的神经网络用来拟合待求解的波函数(即薛定谔方程在数据集所描述的量子力学系统下的解)。该神经网络有2个输出分别表示波函数的实部和虚部。之后在这两个输出后面接上一些求导的操作将这些求导的结果适当的组合起来就可以表示Schrodinger方程作为神经网络的约束项。将波函数的实部、虚部以及一些相关的偏导数作为整个网络的输出。
# [数据集](#目录)
从数据集相应的链接中下载数据集至指定目录(默认'/PINNs/Data/')后可运行相关脚本。文档的后面会介绍如何使用相关脚本。
使用的数据集:[NLS](https://github.com/maziarraissi/PINNs/tree/master/main/Data), 可参照[论文](https://www.sciencedirect.com/science/article/pii/S0021999118307125)
- 数据集大小546KB对一维周期性边界量子力学系统波函数的51456个采样点。
- 训练集150个点
- 测试集整个数据集的全部51456个采样点
- 数据格式mat文件
- 注该数据集在Schrodinger方程场景中使用。数据将在src/Schrodinger/dataset.py中处理。
# [特性](#目录)
## [混合精度](#目录)
采用[混合精度](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_mixed_precision.html)的训练方法使用支持单精度和半精度数据来提高深度学习神经网络的训练速度,同时保持单精度训练所能达到的网络精度。混合精度训练提高计算速度、减少内存使用的同时,支持在特定硬件上训练更大的模型或实现更大批次的训练。
以FP16算子为例如果输入数据类型为FP32MindSpore后台会自动降低精度来处理数据。用户可打开INFO日志搜索“reduce precision”查看精度降低的算子。
# [环境要求](#目录)
- 硬件GPU
- 使用GPU处理器来搭建硬件环境。
- 框架
- [MindSpore](https://www.mindspore.cn/install/en)
- 如需查看详情,请参见如下资源:
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [快速入门](#目录)
通过官方网站安装MindSpore后您可以按照如下步骤进行训练和评估
- GPU处理器环境运行Schrodinger方程场景
```shell
# 运行训练示例
export CUDA_VISIBLE_DEVICES=0
python train.py --scenario=Schrodinger --datapath=[DATASET_PATH] > train.log
OR
bash /scripts/run_standalone_Schrodinger_train.sh [DATASET_PATH]
# 运行评估示例
python eval.py [CHECKPOINT_PATH] --scenario=Schrodinger ----datapath=[DATASET_PATH] > eval.log
OR
bash /scriptsrun_standalone_Schrodinger_eval.sh [CHECKPOINT_PATH] [DATASET_PATH]
```
# [脚本说明](#目录)
## [脚本及样例代码](#目录)
```text
├── model_zoo
├── README.md // 所有模型相关说明
├── PINNs
├── README.md // PINNs相关说明
├── scripts
│ ├──run_standalone_Schrodinger_train.sh // Schrodinger方程GPU训练的shell脚本
| ├──run_standalone_Schrodinger_eval.sh // Schrodinger方程GPU评估的shell脚本
├── src
| ├──Schrodinger //Schrodinger方程场景
│ | ├──dataset.py //创建数据集
│ | ├──net.py // PINNs (Schrodinger) 架构
│ ├──config.py // 参数配置
├── train.py // 训练脚本 (Schrodinger)
├── eval.py // 评估脚本 (Schrodinger)
├── export.py // 将checkpoint文件导出为mindir
├── requirements // 运行PINNs网络额外需要的包
```
## [脚本参数](#目录)
在config.py中可以同时配置训练参数和评估参数。
- 配置Schrodinger方程场景。
```python
'epoch':50000 #训练轮次
'lr':0.0001 #学习率
'N0':50 #训练集在初始条件处的采样点数量,对于NLS数据集0<N0<=256
'Nb':50 #训练集在边界条件处的采样点数量,对于NLS数据集0<Nb<=201
'Nf':20000 #训练时用于计算Schrodinger方程约束的配点数
'num_neuron':100 #PINNs网络全连接隐藏层的神经元数量
'seed':2 #随机种子
'path':'./Data/NLS.mat' #数据集存储路径
'ck_path':'./ckpoints/' #保存checkpoint文件(.ckpt)的路径
```
更多配置细节请参考脚本`config.py`。
## [训练过程](#目录)
- GPU处理器环境运行Schrodinger方程场景
```bash
python train.py --scenario=Schrodinger --datapath=[DATASET_PATH] > train.log 2>&1 &
```
- 以上python命令将在后台运行。您可以通过train.log文件查看结果。
可以采用以下方式达到损失值:
```bash
# grep "loss is " train.log
epoch: 1 step: 1, loss is 1.3523688
epoch time: 7519.499 ms, per step time: 7519.499 ms
epcoh: 2 step: 1, loss is 1.2859955
epoch time: 429.470 ms
...
```
训练结束后,您可在默认`./ckpoints/`脚本文件夹下找到检查点文件。
## [评估过程](#目录)
- 在GPU处理器环境运行Schrodinger方程场景
在运行以下命令之前,请检查用于评估的检查点路径。请将检查点路径设置为绝对全路径,例如“./ckpt/checkpoint_PINNs_Schrodinger-50000_1.ckpt”。
```bash
python eval.py [CHECKPOINT_PATH] --scenario=Schrodinger ----datapath=[DATASET_PATH] > eval.log
```
上述python命令将在后台运行您可以通过eval.log文件查看结果。测试误差如下
```bash
# grep "accuracy:" eval.log
evaluation error is: 0.01207
```
# [模型描述](#目录)
## [性能](#目录)
### [评估性能](#目录)
#### [Schrodinger方程场景评估](#目录)
| 参数 | GPU |
| -------------------------- | ---------------------- |
| 模型版本 | PINNs (Schrodinger) |
| 资源 | NV Tesla V100-32G |
| 上传日期 | 2021-5-20 |
| MindSpore版本 | 1.2.0 |
| 数据集 | NLS |
| 训练参数 | epoch=50000, lr=0.0001. 详见src/config.py |
| 优化器 | Adam |
| 损失函数 | src/Schrodinger/loss.py |
| 输出 | 波函数(实部,虚部),波函数对坐标的一阶导(实部,虚部),对薛定谔方程的拟合(实部,虚部) |
| 损失 | 0.00009928 |
| 速度 | 456毫秒/步 |
| 总时长 | 6.3344 小时 |
| 参数 | 32K |
| 微调检查点 | 363K (.ckpt文件) |
### [推理性能](#目录)
#### [Schrodinger方程场景推理](#目录)
| 参数 | GPU |
| ------------------- | --------------------------- |
| 模型版本 | PINNs (Schrodinger) |
| 资源 | NV Tesla V100-32G |
| 上传日期 | 2021-5-20 |
| MindSpore 版本 | 1.2.0 |
| 数据集 | NLS |
| 输出 | 波函数的实部与虚部 |
| 均方误差 | 0.01323 |
# [随机情况说明](#目录)
在train.py中的使用了随机种子可在src/config.py中修改。
# [ModelZoo主页](#目录)
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。

View File

@ -0,0 +1,81 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Eval"""
import argparse
import numpy as np
from mindspore import Tensor, context
from mindspore import load_checkpoint, load_param_into_net
import mindspore.common.dtype as mstype
from src import config
from src.Schrodinger.dataset import get_eval_data
from src.Schrodinger.net import PINNs
def eval_PINNs_sch(ckpoint_name, num_neuron=100, path='./Data/NLS.mat'):
"""
Evaluation of PINNs for Schrodinger equation scenario.
Args:
ckpoint_name (str): model checkpoint file name
num_neuron (int): number of neurons for fully connected layer in the network
path (str): path of the dataset for Schrodinger equation
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
layers = [2, num_neuron, num_neuron, num_neuron, num_neuron, 2]
lb = np.array([-5.0, 0.0])
ub = np.array([5.0, np.pi/2])
n = PINNs(layers, lb, ub)
param_dict = load_checkpoint(ckpoint_name)
load_param_into_net(n, param_dict)
X_star, _, _, h_star = get_eval_data(path)
X_tensor = Tensor(X_star, mstype.float32)
pred = n(X_tensor)
u_pred = pred[0].asnumpy()
v_pred = pred[1].asnumpy()
h_pred = np.sqrt(u_pred**2 + v_pred**2)
error_h = np.linalg.norm(h_star-h_pred, 2)/np.linalg.norm(h_star, 2)
print(f'evaluation error is: {error_h}')
return error_h
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Evaluate PINNs for Schrodinger equation scenario')
parser.add_argument('ck_file', type=str, help='model checkpoint(ckpt) filename')
#only support 'Schrodinger' for now
parser.add_argument('--scenario', type=str, help='scenario for PINNs', default='Schrodinger')
parser.add_argument('--datapath', type=str, help='path for dataset', default='')
args_opt = parser.parse_args()
f_name = args_opt.ck_file
pinns_scenario = args_opt.scenario
data_path = args_opt.datapath
if pinns_scenario == 'Schrodinger':
conf = config.config_Sch
hidden_size = conf['num_neuron']
if data_path == '':
dataset_path = conf['path']
else:
dataset_path = data_path
mse_error = eval_PINNs_sch(f_name, hidden_size, dataset_path)
else:
print(f'{pinns_scenario} is not supported in PINNs evaluation for now')

View File

@ -0,0 +1,70 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export checkpoint file into air, onnx, mindir models"""
import argparse
import numpy as np
from mindspore import (Tensor, context, export, load_checkpoint,
load_param_into_net)
import mindspore.common.dtype as mstype
from src import config
from src.Schrodinger.net import PINNs
parser = argparse.ArgumentParser(description='PINNs export')
parser.add_argument('ck_file', type=str, help='model checkpoint(ckpt) filename')
parser.add_argument('file_name', type=str, help='export file name')
#only support Schrodinger' for now
parser.add_argument('--scenario', type=str, help='scenario for PINNs', default='Schrodinger')
def export_sch(conf_sch, export_format, export_name):
"""
export PINNs for Schrodinger model
Args:
conf_sch (dict): dictionary for configuration, see src/config.py for details
export_format (str): file format to export
export_name (str): name of exported file
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
num_neuron = conf_sch['num_neuron']
layers = [2, num_neuron, num_neuron, num_neuron, num_neuron, 2]
lb = np.array([-5.0, 0.0])
ub = np.array([5.0, np.pi/2])
n = PINNs(layers, lb, ub)
param_dict = load_checkpoint(ck_file)
load_param_into_net(n, param_dict)
batch_size = conf_sch['N0'] + 2*conf_sch['Nb'] +conf_sch['Nf']
inputs = Tensor(np.ones((batch_size, 2)), mstype.float32)
export(n, inputs, file_name=export_name, file_format=export_format)
if __name__ == '__main__':
args_opt = parser.parse_args()
ck_file = args_opt.ck_file
file_format = 'MINDIR'
file_name = args_opt.file_name
pinns_scenario = args_opt.scenario
conf = config.config_Sch
if pinns_scenario == 'Schrodinger':
export_sch(conf, file_format, file_name)
else:
print(f'{pinns_scenario} scenario in PINNs is not supported to export for now')

View File

@ -0,0 +1 @@
pyDOE>=0.3.8

View File

@ -0,0 +1,45 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
if [ $# != 2 ] && [ $# != 3 ]
then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_Schrodinger_eval.sh [CHECKPOINT] [DATASET] [DEVICE_ID](option, default is 0)"
echo "for example: bash scripts/run_standalone_Schrodinger_eval.sh ckpoints/checkpoint_PINNs_Schrodinger-50000_1.ckptData/NLS.mat 0"
echo "=============================================================================================================="
exit 1
fi
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export DEVICE_ID=0
if [ $# == 3 ];
then
export DEVICE_ID=$3
fi
ck_path=$(get_real_path $1)
data_set_path=$(get_real_path $2)
python ${PROJECT_DIR}/../eval.py $ck_path --scenario=Schrodinger --datapath=$data_set_path > eval.log 2>&1 &

View File

@ -0,0 +1,44 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
if [ $# != 1 ] && [ $# != 2 ]
then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_Schrodinger_train.sh [DATASET] [DEVICE_ID](option, default is 0)"
echo "for example: bash scripts/run_standalone_Schrodinger_train.sh Data/NLS.mat 0"
echo "=============================================================================================================="
exit 1
fi
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
data_set_path=$(get_real_path $1)
export DEVICE_ID=0
if [ $# == 2 ];
then
export DEVICE_ID=$2
fi
python ${PROJECT_DIR}/../train.py --datapath $data_set_path --scenario Schrodinger > train.log 2>&1 &

View File

@ -0,0 +1,14 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@ -0,0 +1,124 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create dataset for training or evaluation"""
import mindspore.dataset as ds
import numpy as np
import scipy.io as scio
from pyDOE import lhs
class PINNs_training_set:
"""
Training set for PINNs (Schrodinger)
Args:
N0 (int): number of sampled training data points for the initial condition
Nb (int): number of sampled training data points for the boundary condition
Nf (int): number of sampled training data points for the collocation points
lb (np.array): lower bound (x, t) of domain
ub (np.array): upper bound (x, t) of domain
"""
def __init__(self, N0, Nb, Nf, lb, ub, path='./Data/NLS.mat'):
data = scio.loadmat(path)
self.N0 = N0
self.Nb = Nb
self.Nf = Nf
self.lb = lb
self.ub = ub
# load data
t = data['tt'].flatten()[:, None]
x = data['x'].flatten()[:, None]
Exact = data['uu']
Exact_u = np.real(Exact)
Exact_v = np.imag(Exact)
idx_x = np.random.choice(x.shape[0], self.N0, replace=False)
self.x0 = x[idx_x, :]
self.u0 = Exact_u[idx_x, 0:1]
self.v0 = Exact_v[idx_x, 0:1]
idx_t = np.random.choice(t.shape[0], self.Nb, replace=False)
self.tb = t[idx_t, :]
self.X_f = self.lb + (self.ub-self.lb)*lhs(2, self.Nf)
def __getitem__(self, index):
if index < self.N0: # N0 initial points
x = np.array([self.x0[index][0]], np.float32)
t = np.array([0], np.float32)
u_target = np.array(self.u0[index], np.float32)
v_target = np.array(self.v0[index], np.float32)
elif self.N0 <= index < self.N0+self.Nb: # Nb lower bound points
ind = index - self.N0
x = np.array([self.lb[0]], np.float32)
t = np.array([self.tb[ind][0]], np.float32)
u_target = np.array([self.ub[0]], np.float32)
v_target = t
elif self.N0+self.Nb <= index < self.N0+2*self.Nb: # Nb upper bound points
ind = index - self.N0 - self.Nb
x = np.array([self.ub[0]], np.float32)
t = np.array([self.tb[ind][0]], np.float32)
u_target = np.array([self.lb[0]], np.float32)
v_target = t
else: # Nf collocation points
ind = index - self.N0 - 2*self.Nb
x = np.array(self.X_f[ind, 0:1], np.float32)
t = np.array(self.X_f[ind, 1:2], np.float32)
u_target = np.array([0], np.float32)
v_target = np.array([0], np.float32)
return np.hstack((x, t)), np.hstack((u_target, v_target))
def __len__(self):
return self.N0+2*self.Nb+self.Nf
def generate_PINNs_training_set(N0, Nb, Nf, lb, ub, path='./Data/NLS.mat'):
"""
Generate training set for PINNs
Args: see class PINNs_train_set
"""
s = PINNs_training_set(N0, Nb, Nf, lb, ub, path)
dataset = ds.GeneratorDataset(source=s, column_names=['data', 'label'], shuffle=False)
dataset = dataset.batch(batch_size=len(s))
return dataset
def get_eval_data(path):
"""
Get the evaluation data for Schrodinger equation.
"""
data = scio.loadmat(path)
t = data['tt'].astype(np.float32).flatten()[:, None]
x = data['x'].astype(np.float32).flatten()[:, None]
Exact = data['uu']
Exact_u = np.real(Exact).astype(np.float32)
Exact_v = np.imag(Exact).astype(np.float32)
Exact_h = np.sqrt(Exact_u**2 + Exact_v**2)
X, T = np.meshgrid(x, t)
X_star = np.hstack((X.flatten()[:, None], T.flatten()[:, None]))
u_star = Exact_u.T.flatten()[:, None]
v_star = Exact_v.T.flatten()[:, None]
h_star = Exact_h.T.flatten()[:, None]
return X_star, u_star, v_star, h_star

View File

@ -0,0 +1,69 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loss function for PINNs (Schrodinger)"""
from mindspore import nn, ops
import mindspore.common.dtype as mstype
class PINNs_loss(nn.Cell):
"""
Loss of the PINNs network, only works with full-batch training. Training data are organized in
the following order: initial condition points ([0:n0]), boundary condition points ([n0:(n0+2*nb)]),
collocation points ([(n0+2*nb)::])
"""
def __init__(self, n0, nb, nf, reduction='sum'):
super(PINNs_loss, self).__init__(reduction)
self.n0 = n0
self.nb = nb
self.nf = nf
self.zeros = ops.Zeros()
self.mse = nn.MSELoss(reduction='mean')
self.f_target = self.zeros((self.nf, 1), mstype.float32)
def construct(self, pred, target):
"""
pred: prediction value (u, v, ux, vx, fu, fv)
target: target[:, 0:1] = u_target, target[:, 0:2] = v_target
"""
u0_pred = pred[0][0:self.n0, 0:1]
u0 = target[0:self.n0, 0:1]
v0_pred = pred[1][0:self.n0, 0:1]
v0 = target[0:self.n0, 1:2]
u_lb_pred = pred[0][self.n0:(self.n0+self.nb), 0:1]
u_ub_pred = pred[0][(self.n0+self.nb):(self.n0+2*self.nb), 0:1]
v_lb_pred = pred[1][self.n0:(self.n0+self.nb), 0:1]
v_ub_pred = pred[1][(self.n0+self.nb):(self.n0+2*self.nb), 0:1]
ux_lb_pred = pred[2][self.n0:(self.n0+self.nb), 0:1]
ux_ub_pred = pred[2][(self.n0+self.nb):(self.n0+2*self.nb), 0:1]
vx_lb_pred = pred[3][self.n0:(self.n0+self.nb), 0:1]
vx_ub_pred = pred[3][(self.n0+self.nb):(self.n0+2*self.nb), 0:1]
fu_pred = pred[4][(self.n0+2*self.nb)::, 0:1]
fv_pred = pred[5][(self.n0+2*self.nb)::, 0:1]
mse_u_0 = self.mse(u0_pred, u0)
mse_v_0 = self.mse(v0_pred, v0)
mse_u_b = self.mse(u_lb_pred, u_ub_pred)
mse_v_b = self.mse(v_lb_pred, v_ub_pred)
mse_ux_b = self.mse(ux_lb_pred, ux_ub_pred)
mse_vx_b = self.mse(vx_lb_pred, vx_ub_pred)
mse_fu = self.mse(fu_pred, self.f_target)
mse_fv = self.mse(fv_pred, self.f_target)
ans = mse_u_0 + mse_v_0 + mse_u_b + mse_v_b + mse_ux_b + mse_vx_b + mse_fu + mse_fv
return ans

View File

@ -0,0 +1,166 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the PINNs network for the Schrodinger equation."""
import numpy as np
from mindspore import Parameter, Tensor, nn, ops
from mindspore.common.initializer import TruncatedNormal, Zero, initializer
from mindspore.ops import constexpr
import mindspore.common.dtype as mstype
@constexpr
def _generate_ones(batch_size):
arr = np.ones((batch_size, 1), np.float32)
return Tensor(arr, mstype.float32)
@constexpr
def _generate_zeros(batch_size):
arr = np.zeros((batch_size, 1), np.float32)
return Tensor(arr, mstype.float32)
class neural_net(nn.Cell):
"""
Neural net to fit the wave function
Args:
layers (int): num of neurons for each layer
lb (np.array): lower bound (x, t) of domain
ub (np.array): upper bound (x, t) of domain
"""
def __init__(self, layers, lb, ub):
super(neural_net, self).__init__()
self.layers = layers
self.concat = ops.Concat(axis=1)
self.lb = Tensor(lb, mstype.float32)
self.ub = Tensor(ub, mstype.float32)
self.tanh = ops.Tanh()
self.add = ops.Add()
self.matmul = ops.MatMul()
self.w0 = self._init_weight_xavier(0)
self.b0 = self._init_biase(0)
self.w1 = self._init_weight_xavier(1)
self.b1 = self._init_biase(1)
self.w2 = self._init_weight_xavier(2)
self.b2 = self._init_biase(2)
self.w3 = self._init_weight_xavier(3)
self.b3 = self._init_biase(3)
self.w4 = self._init_weight_xavier(4)
self.b4 = self._init_biase(4)
def construct(self, x, t):
"""forward propagation"""
X = self.concat((x, t))
X = 2.0*(X - self.lb)/(self.ub - self.lb) - 1.0
X = self.tanh(self.add(self.matmul(X, self.w0), self.b0))
X = self.tanh(self.add(self.matmul(X, self.w1), self.b1))
X = self.tanh(self.add(self.matmul(X, self.w2), self.b2))
X = self.tanh(self.add(self.matmul(X, self.w3), self.b3))
X = self.add(self.matmul(X, self.w4), self.b4)
return X[:, 0:1], X[:, 1:2]
def _init_weight_xavier(self, layer):
"""
Initialize weight for the ith layer
"""
in_dim = self.layers[layer]
out_dim = self.layers[layer+1]
std = np.sqrt(2/(in_dim + out_dim))
name = 'w' + str(layer)
return Parameter(default_input=initializer(TruncatedNormal(std), [in_dim, out_dim], mstype.float32),
name=name, requires_grad=True)
def _init_biase(self, layer):
"""
Initialize biase for the ith layer
"""
name = 'b' + str(layer)
return Parameter(default_input=initializer(Zero(), self.layers[layer+1], mstype.float32),
name=name, requires_grad=True)
class Grad_1(nn.Cell):
"""
Using the first output to compute gradient.
"""
def __init__(self, net):
super(Grad_1, self).__init__()
self.net = net
self.grad = ops.GradOperation(get_all=True, sens_param=True)
def construct(self, x, t):
sens_1 = _generate_ones(x.shape[0])
sens_2 = _generate_zeros(x.shape[0])
return self.grad(self.net)(x, t, (sens_1, sens_2))
class Grad_2(nn.Cell):
"""
Using the second output to compute gradient.
"""
def __init__(self, net):
super(Grad_2, self).__init__()
self.net = net
self.grad = ops.GradOperation(get_all=True, sens_param=True)
def construct(self, x, t):
sens_1 = _generate_zeros(x.shape[0])
sens_2 = _generate_ones(x.shape[0])
return self.grad(self.net)(x, t, (sens_1, sens_2))
class PINNs(nn.Cell):
"""
PINNs for the Schrodinger equation.
"""
def __init__(self, layers, lb, ub):
super(PINNs, self).__init__()
self.nn = neural_net(layers, lb, ub)
self.du = Grad_1(self.nn)
self.dv = Grad_2(self.nn)
self.dux = Grad_1(self.du)
self.dvx = Grad_1(self.dv)
self.add = ops.Add()
self.pow = ops.Pow()
self.mul = ops.Mul()
def construct(self, X):
"""forward propagation"""
x = X[:, 0:1]
t = X[:, 1:2]
u, v = self.nn(x, t)
ux, ut = self.du(x, t)
vx, vt = self.dv(x, t)
uxx, _ = self.dux(x, t)
vxx, _ = self.dvx(x, t)
square_sum = self.add(self.pow(u, 2), self.pow(v, 2))
fu1 = self.mul(vxx, 0.5)
fu2 = self.mul(square_sum, v)
fu = self.add(self.add(ut, fu1), fu2)
fv1 = self.mul(uxx, -0.5)
fv2 = self.mul(square_sum, u)
fv2 = self.mul(fv2, -1.0)
fv = self.add(self.add(vt, fv1), fv2)
return u, v, ux, vx, fu, fv

View File

@ -0,0 +1,14 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

View File

@ -0,0 +1,22 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Network config setting
"""
# config for Schrodinger equation scenario
config_Sch = {'epoch': 50000, 'lr': 0.0001, 'N0': 50, 'Nb': 50, 'Nf': 20000, 'num_neuron': 100,
'seed': 2, 'path': './Data/NLS.mat', 'ck_path': './ckpoints/'}

View File

@ -0,0 +1,95 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train PINNs"""
import argparse
import numpy as np
from mindspore import Model, context, nn
from mindspore.train.callback import (CheckpointConfig, LossMonitor,
ModelCheckpoint, TimeMonitor)
from mindspore.common import set_seed
from src import config
from src.Schrodinger.dataset import generate_PINNs_training_set
from src.Schrodinger.loss import PINNs_loss
from src.Schrodinger.net import PINNs
def train_sch(epoch=50000, lr=0.0001, N0=50, Nb=50, Nf=20000, num_neuron=100, seed=None,
path='./Data/NLS.mat', ck_path='./ckpoints/'):
"""
Train PINNs network for Schrodinger equation
Args:
epoch (int): number of epochs
lr (float): learning rate
N0 (int): number of data points sampled from the initial condition,
0<N0<=256 for the default NLS dataset
Nb (int): number of data points sampled from the boundary condition,
0<Nb<=201 for the default NLS dataset. Size of training set = N0+2*Nb
Nf (int): number of collocation points, collocation points are used
to calculate regularizer for the network from Schoringer equation.
0<Nf<=51456 for the default NLS dataset
num_neuron (int): number of neurons for fully connected layer in the network
seed (int): random seed
path (str): path of the dataset for Schrodinger equation
ck_path (str): path to store checkpoint files (.ckpt)
"""
if seed is not None:
np.random.seed(seed)
set_seed(seed)
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
layers = [2, num_neuron, num_neuron, num_neuron, num_neuron, 2]
lb = np.array([-5.0, 0.0])
ub = np.array([5.0, np.pi/2])
training_set = generate_PINNs_training_set(N0, Nb, Nf, lb, ub, path=path)
n = PINNs(layers, lb, ub)
opt = nn.Adam(n.trainable_params(), learning_rate=lr)
loss = PINNs_loss(N0, Nb, Nf)
#call back configuration
loss_print_num = 1 # print loss per loss_print_num epochs
# save model
config_ck = CheckpointConfig(save_checkpoint_steps=1000, keep_checkpoint_max=50)
ckpoint = ModelCheckpoint(prefix="checkpoint_PINNs_Schrodinger", directory=ck_path, config=config_ck)
model = Model(network=n, loss_fn=loss, optimizer=opt)
model.train(epoch=epoch, train_dataset=training_set,
callbacks=[LossMonitor(loss_print_num), ckpoint, TimeMonitor(1)], dataset_sink_mode=True)
print('Training complete')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Train PINNs')
#only support 'Schrodinger' for now
parser.add_argument('--scenario', type=str, help='scenario for PINNs', default='Schrodinger')
parser.add_argument('--datapath', type=str, help='path for dataset', default='')
args_opt = parser.parse_args()
pinns_scenario = args_opt.scenario
data_path = args_opt.datapath
if pinns_scenario == 'Schrodinger':
conf = config.config_Sch
if data_path != '':
conf['path'] = data_path
train_sch(**conf)
else:
print(f'{pinns_scenario} is not supported in PINNs training for now')