forked from mindspore-Ecosystem/mindspore
PINNs (Schrodinger)
This commit is contained in:
parent
84837b0bc9
commit
bbeaebaf0b
|
@ -0,0 +1,218 @@
|
|||
# Contents
|
||||
|
||||
[查看中文](./README_CN.md)
|
||||
|
||||
- [PINNs Description](#PINNs-Description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Schrodinger equation](#Schrodinger-equation)
|
||||
- [Dataset](#dataset)
|
||||
- [Features](#features)
|
||||
- [Mixed Precision](#mixed-precision)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Evaluation Performance](#evaluation-performance)
|
||||
- [Evaluation of Schrodinger equation scenario](#Evaluation-of-Schrodinger-equation-scenario)
|
||||
- [Inference Performance](#evaluation-performance)
|
||||
- [Inference of Schrodinger equation scenario](#Inference-of-Schrodinger-equation-scenario)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
# [PINNs Description](#contents)
|
||||
|
||||
PINNs (Physics Information Neural Networks) is a neural network proposed in 2019. PINNs network provides a new approach for solving partial differential equations with neural network. Partial differential equations are often used in the modeling of physical, biological and engineering systems. The characteristics of such systems have significantly difference from most problems in machine learning: (1) the cost of data acquisition is high, and the amount of data is usually small;(2) a large amount of priori knowledge, such as previous research result like physical laws, are hard to be utilized by machine learning systems.
|
||||
|
||||
In PINNs, firstly the prior knowledge in the form of partial differential equation is introduced as the regularization term of the network through proper construction of the Pinns network. Then, by utilizing the prior knowledge in PINNs, the network can train very good results with very little data.
|
||||
|
||||
[paper](https://www.sciencedirect.com/science/article/pii/S0021999118307125):Raissi, Maziar, Paris Perdikaris, and George E. Karniadakis. "Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations."*Journal of Computational Physics*. 2019 (378): 686-707.
|
||||
|
||||
# [Model Architecture](#contents)
|
||||
|
||||
Pinns is a new framework of constructing neural network for solving partial differential equations. The specific model structure will change according to the partial differential equations. The network structure of each application scenario of PINNs implemented in MindSpore is as follows:
|
||||
|
||||
## [Schrodinger equation](#Contents)
|
||||
|
||||
The PINNs of the Schrodinger equation can be divided into two parts. First, a neural network composed of five fully connected layers is used to fit the wave function to be solved (i.e., the solution of the Schrodinger equation in the quantum mechanics system described by the data set). The neural network has two outputs, which represent the real part and the imaginary part of the wave function respectively. Then, the two outputs are followed by some derivative operations. The Schrodinger equation can be expressed by properly combining these derivative results, and act as a constraint term of the neural network. The outputs of the whole network are the real part, imaginary part and some related partial derivatives of the wave function.
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
Note that you can run the scripts based on the dataset once you have downloaded the dataset from the corresponding link to the data storage path (default path is '/PINNs/Data/') . In the following sections, we will introduce how to run the scripts using the related dataset below.
|
||||
|
||||
Dataset used: [NLS](https://github.com/maziarraissi/PINNs/tree/master/main/Data), can refer to [paper](https://www.sciencedirect.com/science/article/pii/S0021999118307125)
|
||||
|
||||
- Dataset size:546KB,51456 points sampled from the wave function of a one-dimensional quantum mechanics system with periodic boundary conditions.
|
||||
- Train:150 data points
|
||||
- Test:All 51456 data points of the dataset.
|
||||
- Data format:mat files
|
||||
- Note:This dataset is used in the Schrodinger equation scenario. Data will be processed in src/Schrodinger/dataset.py
|
||||
|
||||
# [Features](#contents)
|
||||
|
||||
## [Mixed Precision](#Contents)
|
||||
|
||||
The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
|
||||
For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching ‘reduce precision’.
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(GPU)
|
||||
- Prepare hardware environment with GPU processor.
|
||||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [Quick Start](#contents)
|
||||
|
||||
After installing MindSpore via the official website, you can start training and evaluation as follows:
|
||||
|
||||
- Schrodinger equation scenario running on GPU
|
||||
|
||||
```shell
|
||||
# Running training example
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
python train.py --scenario=Schrodinger --datapath=./Data/mat > train.log
|
||||
OR
|
||||
bash /scripts/run_standalone_Schrodinger_train.sh Schrodinger
|
||||
|
||||
# Running evaluation example
|
||||
python eval.py [CHECKPOINT_PATH] --scenario=Schrodinger ----datapath=[DATASET_PATH] > eval.log
|
||||
OR
|
||||
bash /scriptsrun_standalone_Schrodinger_eval.sh [CHECKPOINT_PATH] [DATASET_PATH]
|
||||
```
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
|
||||
```text
|
||||
├── model_zoo
|
||||
├── README.md // descriptions about all the models
|
||||
├── PINNs
|
||||
├── README.md // descriptions about PINNs
|
||||
├── scripts
|
||||
│ ├──run_standalone_Schrodinger_train.sh // shell script for Schrodinger equation scenario training on GPU
|
||||
| ├──run_standalone_Schrodinger_eval.sh // shell script for Schrodinger equation scenario evaluation on GPU
|
||||
├── src
|
||||
| ├──Schrodinger //Schrodinger equation scenario
|
||||
│ | ├──dataset.py // creating dataset
|
||||
│ | ├──net.py // PINNs (Schrodinger) architecture
|
||||
│ ├──config.py // parameter configuration
|
||||
├── train.py // training script (Schrodinger)
|
||||
├── eval.py // evaluation script (Schrodinger)
|
||||
├── export.py // export checkpoint files into mindir (Schrodinger) ├── ├── requirements // additional packages required to run PINNs networks
|
||||
```
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
|
||||
Parameters for both training and evaluation can be set in config.py
|
||||
|
||||
- config for Schrodinger equation scenario
|
||||
|
||||
```python
|
||||
'epoch':50000 # number of epochs in training
|
||||
'lr':0.0001 # learning rate
|
||||
'N0':50 # number of sampling points of the training set at the initial condition. For the NLS dataset, 0<N0<=256
|
||||
'Nb':50 # number of sampling points of the training set at the boundary condition. For the NLS dataset, 0<N0<=201
|
||||
'Nf':20000 # number of collocations points used to calculate the constraint of Schrodinger equation in training. For the NLS dataset, 0<Nf<=51456
|
||||
'num_neuron':100 # number of neurons in fully connected hidden layer of PINNs network for Schrodinger equation
|
||||
'seed':2 # random seed
|
||||
'path':'./Data/NLS.mat' # data set storage path
|
||||
'ck_path':'./ckpoints/' # path to save checkpoint files (.ckpt)
|
||||
```
|
||||
|
||||
For more configuration details, please refer the script `config.py`.
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
- Running Schrodinger equation scenario on GPU
|
||||
|
||||
```bash
|
||||
python train.py --scenario=Schrodinger --datapath=[DATASET_PATH] > train.log 2>&1 &
|
||||
```
|
||||
|
||||
- The python command above will run in the background, you can view the results through the file `train.log`。
|
||||
|
||||
The loss value can be achieved as follows:
|
||||
|
||||
```bash
|
||||
# grep "loss is " train.log
|
||||
epoch: 1 step: 1, loss is 1.3523688
|
||||
epoch time: 7519.499 ms, per step time: 7519.499 ms
|
||||
epcoh: 2 step: 1, loss is 1.2859955
|
||||
epoch time: 429.470 ms
|
||||
...
|
||||
```
|
||||
|
||||
After training, you'll get some checkpoint files under the folder `./ckpoints/` by default.
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
- evaluation of Schrodinger equation scenario when running on GPU
|
||||
|
||||
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to be the absolute full path, e.g., “./ckpt/checkpoint_PINNs_Schrodinger-50000_1.ckpt”。
|
||||
|
||||
```bash
|
||||
python eval.py [CHECKPOINT_PATH] --scenario=Schrodinger ----datapath=[DATASET_PATH] > eval.log
|
||||
```
|
||||
|
||||
The above python command will run in the background. You can view the results through the file "eval.log". The error of evaluation is as follows:
|
||||
|
||||
```bash
|
||||
# grep "accuracy:" eval.log
|
||||
evaluation error is: 0.01207
|
||||
```
|
||||
|
||||
# [Model Description](#contents)
|
||||
|
||||
## [Performance](#contents)
|
||||
|
||||
### [Evaluation Performance](#contents)
|
||||
|
||||
#### [Evaluation of Schrodinger equation scenario](#contents)
|
||||
|
||||
| Parameters | GPU |
|
||||
| -------------------------- | ------------------------------------------------------------ |
|
||||
| Model Version | PINNs (Schrodinger) |
|
||||
| Resource | NV Tesla V100-32G |
|
||||
| uploaded Date | 5/20/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.2.0 |
|
||||
| Dataset | NLS |
|
||||
| Training Parameters | epoch=50000, lr=0.0001. see src/config.py for details |
|
||||
| Optimizer | Adam |
|
||||
| Loss Function | src/Schrodinger/loss.py |
|
||||
| outputs | the wave function (real part and imaginary part),first order derivative of the wave function to the coordinates (real part and imaginary part),the fitting of the Schrodinger equation (real part and imaginary part) |
|
||||
| Loss | 0.00009928 |
|
||||
| Speed | 456ms/step |
|
||||
| Total time | 6.3344 hours |
|
||||
| Parameters | 32K |
|
||||
| Checkpoint for Fine tuning | 363K (.ckpt file) |
|
||||
|
||||
### [Inference Performance](#contents)
|
||||
|
||||
#### [Inference of Schrodinger equation scenario](#contents)
|
||||
|
||||
| Parameters | GPU |
|
||||
| ----------------- | -------------------------------------------- |
|
||||
| Model Version | PINNs (Schrodinger) |
|
||||
| Resource | NV Tesla V100-32G |
|
||||
| uploaded Date | 5/20/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.2.0 |
|
||||
| Dataset | NLS |
|
||||
| outputs | real part and imaginary of the wave function |
|
||||
| mean square error | 0.01323 |
|
||||
|
||||
# [Description of Random Situation](#contents)
|
||||
|
||||
We use random seed in train.py,which can be reset in src/config.py.
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,219 @@
|
|||
# 目录
|
||||
|
||||
[View English](./README.md)
|
||||
|
||||
- [目录](#目录)
|
||||
- [PINNs描述](#PINNs描述)
|
||||
- [模型架构](#模型架构)
|
||||
- [Schrodinger方程](#Schrodinger方程)
|
||||
- [数据集](#数据集)
|
||||
- [特性](#特性)
|
||||
- [混合精度](#混合精度)
|
||||
- [环境要求](#环境要求)
|
||||
- [快速入门](#快速入门)
|
||||
- [脚本说明](#脚本说明)
|
||||
- [脚本及样例代码](#脚本及样例代码)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [训练过程](#训练过程)
|
||||
- [评估过程](#评估过程)
|
||||
- [评估](#评估)
|
||||
- [模型描述](#模型描述)
|
||||
- [性能](#性能)
|
||||
- [评估性能](#评估性能)
|
||||
- [Schrodinger方程场景评估](#Schrodinger方程场景评估)
|
||||
- [推理性能](#推理性能)
|
||||
- [Schrodinger方程场景推理](#Schrodinger方程场景推理)
|
||||
- [随机情况说明](#随机情况说明)
|
||||
- [ModelZoo主页](#modelzoo主页)
|
||||
|
||||
# [PINNs描述](#目录)
|
||||
|
||||
PINNs (Physics-informed neural networks)是2019年提出的神经网络。PINNs网络提供了一种全新的用神经网络求解偏微分方程的思路。对现实的物理、生物、工程等系统建模时,常常会用到偏微分方程。而此类问题的特征与机器学习中遇到的大多数问题有两点显著不同:(1)获取数据的成本较高,数据量通常较小;(2)存在大量前人对于此类问题的研究成果作为先验知识而无法被机器学习系统利用,例如各种物理定律等。PINNs网络首先通过适当的构造,将偏微分方程形式的先验知识作为网络的正则化约束引入,进而通过利用这些先验知识强大的约束作用,使得网络能够用很少的数据就训练出很好的结果。PINNs网络在量子力学等场景中经过了成功的验证,能够用很少的数据成功训练网络并对相应的物理系统进行建模。
|
||||
|
||||
[论文](https://www.sciencedirect.com/science/article/pii/S0021999118307125):Raissi, Maziar, Paris Perdikaris, and George E. Karniadakis. "Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations."*Journal of Computational Physics*. 2019 (378): 686-707.
|
||||
|
||||
# [模型架构](#目录)
|
||||
|
||||
PINNs是针对偏微分方程问题构造神经网络的思路,具体的模型结构会根据所要求解的偏微分方程而有相应的变化。在MindSpore中实现的PINNs各应用场景的网络结构如下:
|
||||
|
||||
## [Schrodinger方程](#目录)
|
||||
|
||||
针对Schrodinger方程的PINNs分为两部分,首先是一个由5个全连接层组成的神经网络用来拟合待求解的波函数(即薛定谔方程在数据集所描述的量子力学系统下的解)。该神经网络有2个输出分别表示波函数的实部和虚部。之后在这两个输出后面接上一些求导的操作,将这些求导的结果适当的组合起来就可以表示Schrodinger方程,作为神经网络的约束项。将波函数的实部、虚部以及一些相关的偏导数作为整个网络的输出。
|
||||
|
||||
# [数据集](#目录)
|
||||
|
||||
从数据集相应的链接中下载数据集至指定目录(默认'/PINNs/Data/')后可运行相关脚本。文档的后面会介绍如何使用相关脚本。
|
||||
|
||||
使用的数据集:[NLS](https://github.com/maziarraissi/PINNs/tree/master/main/Data), 可参照[论文](https://www.sciencedirect.com/science/article/pii/S0021999118307125)
|
||||
|
||||
- 数据集大小:546KB,对一维周期性边界量子力学系统波函数的51456个采样点。
|
||||
- 训练集:150个点
|
||||
- 测试集:整个数据集的全部51456个采样点
|
||||
- 数据格式:mat文件
|
||||
- 注:该数据集在Schrodinger方程场景中使用。数据将在src/Schrodinger/dataset.py中处理。
|
||||
|
||||
# [特性](#目录)
|
||||
|
||||
## [混合精度](#目录)
|
||||
|
||||
采用[混合精度](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_mixed_precision.html)的训练方法使用支持单精度和半精度数据来提高深度学习神经网络的训练速度,同时保持单精度训练所能达到的网络精度。混合精度训练提高计算速度、减少内存使用的同时,支持在特定硬件上训练更大的模型或实现更大批次的训练。
|
||||
以FP16算子为例,如果输入数据类型为FP32,MindSpore后台会自动降低精度来处理数据。用户可打开INFO日志,搜索“reduce precision”查看精度降低的算子。
|
||||
|
||||
# [环境要求](#目录)
|
||||
|
||||
- 硬件(GPU)
|
||||
- 使用GPU处理器来搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
# [快速入门](#目录)
|
||||
|
||||
通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估:
|
||||
|
||||
- GPU处理器环境运行Schrodinger方程场景
|
||||
|
||||
```shell
|
||||
# 运行训练示例
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
python train.py --scenario=Schrodinger --datapath=[DATASET_PATH] > train.log
|
||||
OR
|
||||
bash /scripts/run_standalone_Schrodinger_train.sh [DATASET_PATH]
|
||||
|
||||
# 运行评估示例
|
||||
python eval.py [CHECKPOINT_PATH] --scenario=Schrodinger ----datapath=[DATASET_PATH] > eval.log
|
||||
OR
|
||||
bash /scriptsrun_standalone_Schrodinger_eval.sh [CHECKPOINT_PATH] [DATASET_PATH]
|
||||
```
|
||||
|
||||
# [脚本说明](#目录)
|
||||
|
||||
## [脚本及样例代码](#目录)
|
||||
|
||||
```text
|
||||
├── model_zoo
|
||||
├── README.md // 所有模型相关说明
|
||||
├── PINNs
|
||||
├── README.md // PINNs相关说明
|
||||
├── scripts
|
||||
│ ├──run_standalone_Schrodinger_train.sh // Schrodinger方程GPU训练的shell脚本
|
||||
| ├──run_standalone_Schrodinger_eval.sh // Schrodinger方程GPU评估的shell脚本
|
||||
├── src
|
||||
| ├──Schrodinger //Schrodinger方程场景
|
||||
│ | ├──dataset.py //创建数据集
|
||||
│ | ├──net.py // PINNs (Schrodinger) 架构
|
||||
│ ├──config.py // 参数配置
|
||||
├── train.py // 训练脚本 (Schrodinger)
|
||||
├── eval.py // 评估脚本 (Schrodinger)
|
||||
├── export.py // 将checkpoint文件导出为mindir
|
||||
├── requirements // 运行PINNs网络额外需要的包
|
||||
```
|
||||
|
||||
## [脚本参数](#目录)
|
||||
|
||||
在config.py中可以同时配置训练参数和评估参数。
|
||||
|
||||
- 配置Schrodinger方程场景。
|
||||
|
||||
```python
|
||||
'epoch':50000 #训练轮次
|
||||
'lr':0.0001 #学习率
|
||||
'N0':50 #训练集在初始条件处的采样点数量,对于NLS数据集,0<N0<=256
|
||||
'Nb':50 #训练集在边界条件处的采样点数量,对于NLS数据集,0<Nb<=201
|
||||
'Nf':20000 #训练时用于计算Schrodinger方程约束的配点数
|
||||
'num_neuron':100 #PINNs网络全连接隐藏层的神经元数量
|
||||
'seed':2 #随机种子
|
||||
'path':'./Data/NLS.mat' #数据集存储路径
|
||||
'ck_path':'./ckpoints/' #保存checkpoint文件(.ckpt)的路径
|
||||
```
|
||||
|
||||
更多配置细节请参考脚本`config.py`。
|
||||
|
||||
## [训练过程](#目录)
|
||||
|
||||
- GPU处理器环境运行Schrodinger方程场景
|
||||
|
||||
```bash
|
||||
python train.py --scenario=Schrodinger --datapath=[DATASET_PATH] > train.log 2>&1 &
|
||||
```
|
||||
|
||||
- 以上python命令将在后台运行。您可以通过train.log文件查看结果。
|
||||
|
||||
可以采用以下方式达到损失值:
|
||||
|
||||
```bash
|
||||
# grep "loss is " train.log
|
||||
epoch: 1 step: 1, loss is 1.3523688
|
||||
epoch time: 7519.499 ms, per step time: 7519.499 ms
|
||||
epcoh: 2 step: 1, loss is 1.2859955
|
||||
epoch time: 429.470 ms
|
||||
...
|
||||
```
|
||||
|
||||
训练结束后,您可在默认`./ckpoints/`脚本文件夹下找到检查点文件。
|
||||
|
||||
## [评估过程](#目录)
|
||||
|
||||
- 在GPU处理器环境运行Schrodinger方程场景
|
||||
|
||||
在运行以下命令之前,请检查用于评估的检查点路径。请将检查点路径设置为绝对全路径,例如“./ckpt/checkpoint_PINNs_Schrodinger-50000_1.ckpt”。
|
||||
|
||||
```bash
|
||||
python eval.py [CHECKPOINT_PATH] --scenario=Schrodinger ----datapath=[DATASET_PATH] > eval.log
|
||||
```
|
||||
|
||||
上述python命令将在后台运行,您可以通过eval.log文件查看结果。测试误差如下:
|
||||
|
||||
```bash
|
||||
# grep "accuracy:" eval.log
|
||||
evaluation error is: 0.01207
|
||||
```
|
||||
|
||||
# [模型描述](#目录)
|
||||
|
||||
## [性能](#目录)
|
||||
|
||||
### [评估性能](#目录)
|
||||
|
||||
#### [Schrodinger方程场景评估](#目录)
|
||||
|
||||
| 参数 | GPU |
|
||||
| -------------------------- | ---------------------- |
|
||||
| 模型版本 | PINNs (Schrodinger) |
|
||||
| 资源 | NV Tesla V100-32G |
|
||||
| 上传日期 | 2021-5-20 |
|
||||
| MindSpore版本 | 1.2.0 |
|
||||
| 数据集 | NLS |
|
||||
| 训练参数 | epoch=50000, lr=0.0001. 详见src/config.py |
|
||||
| 优化器 | Adam |
|
||||
| 损失函数 | src/Schrodinger/loss.py |
|
||||
| 输出 | 波函数(实部,虚部),波函数对坐标的一阶导(实部,虚部),对薛定谔方程的拟合(实部,虚部) |
|
||||
| 损失 | 0.00009928 |
|
||||
| 速度 | 456毫秒/步 |
|
||||
| 总时长 | 6.3344 小时 |
|
||||
| 参数 | 32K |
|
||||
| 微调检查点 | 363K (.ckpt文件) |
|
||||
|
||||
### [推理性能](#目录)
|
||||
|
||||
#### [Schrodinger方程场景推理](#目录)
|
||||
|
||||
| 参数 | GPU |
|
||||
| ------------------- | --------------------------- |
|
||||
| 模型版本 | PINNs (Schrodinger) |
|
||||
| 资源 | NV Tesla V100-32G |
|
||||
| 上传日期 | 2021-5-20 |
|
||||
| MindSpore 版本 | 1.2.0 |
|
||||
| 数据集 | NLS |
|
||||
| 输出 | 波函数的实部与虚部 |
|
||||
| 均方误差 | 0.01323 |
|
||||
|
||||
# [随机情况说明](#目录)
|
||||
|
||||
在train.py中的使用了随机种子,可在src/config.py中修改。
|
||||
|
||||
# [ModelZoo主页](#目录)
|
||||
|
||||
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Eval"""
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from mindspore import Tensor, context
|
||||
from mindspore import load_checkpoint, load_param_into_net
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from src import config
|
||||
from src.Schrodinger.dataset import get_eval_data
|
||||
from src.Schrodinger.net import PINNs
|
||||
|
||||
def eval_PINNs_sch(ckpoint_name, num_neuron=100, path='./Data/NLS.mat'):
|
||||
"""
|
||||
Evaluation of PINNs for Schrodinger equation scenario.
|
||||
|
||||
Args:
|
||||
ckpoint_name (str): model checkpoint file name
|
||||
num_neuron (int): number of neurons for fully connected layer in the network
|
||||
path (str): path of the dataset for Schrodinger equation
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
layers = [2, num_neuron, num_neuron, num_neuron, num_neuron, 2]
|
||||
lb = np.array([-5.0, 0.0])
|
||||
ub = np.array([5.0, np.pi/2])
|
||||
|
||||
n = PINNs(layers, lb, ub)
|
||||
param_dict = load_checkpoint(ckpoint_name)
|
||||
load_param_into_net(n, param_dict)
|
||||
|
||||
X_star, _, _, h_star = get_eval_data(path)
|
||||
|
||||
X_tensor = Tensor(X_star, mstype.float32)
|
||||
pred = n(X_tensor)
|
||||
u_pred = pred[0].asnumpy()
|
||||
v_pred = pred[1].asnumpy()
|
||||
h_pred = np.sqrt(u_pred**2 + v_pred**2)
|
||||
|
||||
error_h = np.linalg.norm(h_star-h_pred, 2)/np.linalg.norm(h_star, 2)
|
||||
print(f'evaluation error is: {error_h}')
|
||||
|
||||
return error_h
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Evaluate PINNs for Schrodinger equation scenario')
|
||||
parser.add_argument('ck_file', type=str, help='model checkpoint(ckpt) filename')
|
||||
|
||||
#only support 'Schrodinger' for now
|
||||
parser.add_argument('--scenario', type=str, help='scenario for PINNs', default='Schrodinger')
|
||||
parser.add_argument('--datapath', type=str, help='path for dataset', default='')
|
||||
|
||||
args_opt = parser.parse_args()
|
||||
f_name = args_opt.ck_file
|
||||
pinns_scenario = args_opt.scenario
|
||||
data_path = args_opt.datapath
|
||||
|
||||
if pinns_scenario == 'Schrodinger':
|
||||
conf = config.config_Sch
|
||||
hidden_size = conf['num_neuron']
|
||||
if data_path == '':
|
||||
dataset_path = conf['path']
|
||||
else:
|
||||
dataset_path = data_path
|
||||
mse_error = eval_PINNs_sch(f_name, hidden_size, dataset_path)
|
||||
else:
|
||||
print(f'{pinns_scenario} is not supported in PINNs evaluation for now')
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""export checkpoint file into air, onnx, mindir models"""
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from mindspore import (Tensor, context, export, load_checkpoint,
|
||||
load_param_into_net)
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from src import config
|
||||
from src.Schrodinger.net import PINNs
|
||||
|
||||
parser = argparse.ArgumentParser(description='PINNs export')
|
||||
parser.add_argument('ck_file', type=str, help='model checkpoint(ckpt) filename')
|
||||
parser.add_argument('file_name', type=str, help='export file name')
|
||||
|
||||
#only support ‘Schrodinger' for now
|
||||
parser.add_argument('--scenario', type=str, help='scenario for PINNs', default='Schrodinger')
|
||||
|
||||
|
||||
def export_sch(conf_sch, export_format, export_name):
|
||||
"""
|
||||
export PINNs for Schrodinger model
|
||||
|
||||
Args:
|
||||
conf_sch (dict): dictionary for configuration, see src/config.py for details
|
||||
export_format (str): file format to export
|
||||
export_name (str): name of exported file
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
|
||||
num_neuron = conf_sch['num_neuron']
|
||||
layers = [2, num_neuron, num_neuron, num_neuron, num_neuron, 2]
|
||||
lb = np.array([-5.0, 0.0])
|
||||
ub = np.array([5.0, np.pi/2])
|
||||
|
||||
n = PINNs(layers, lb, ub)
|
||||
param_dict = load_checkpoint(ck_file)
|
||||
load_param_into_net(n, param_dict)
|
||||
|
||||
batch_size = conf_sch['N0'] + 2*conf_sch['Nb'] +conf_sch['Nf']
|
||||
inputs = Tensor(np.ones((batch_size, 2)), mstype.float32)
|
||||
export(n, inputs, file_name=export_name, file_format=export_format)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args_opt = parser.parse_args()
|
||||
ck_file = args_opt.ck_file
|
||||
file_format = 'MINDIR'
|
||||
file_name = args_opt.file_name
|
||||
pinns_scenario = args_opt.scenario
|
||||
conf = config.config_Sch
|
||||
|
||||
if pinns_scenario == 'Schrodinger':
|
||||
export_sch(conf, file_format, file_name)
|
||||
else:
|
||||
print(f'{pinns_scenario} scenario in PINNs is not supported to export for now')
|
|
@ -0,0 +1 @@
|
|||
pyDOE>=0.3.8
|
|
@ -0,0 +1,45 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
if [ $# != 2 ] && [ $# != 3 ]
|
||||
then
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash scripts/run_standalone_Schrodinger_eval.sh [CHECKPOINT] [DATASET] [DEVICE_ID](option, default is 0)"
|
||||
echo "for example: bash scripts/run_standalone_Schrodinger_eval.sh ckpoints/checkpoint_PINNs_Schrodinger-50000_1.ckptData/NLS.mat 0"
|
||||
echo "=============================================================================================================="
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||
|
||||
export DEVICE_ID=0
|
||||
if [ $# == 3 ];
|
||||
then
|
||||
export DEVICE_ID=$3
|
||||
fi
|
||||
|
||||
ck_path=$(get_real_path $1)
|
||||
data_set_path=$(get_real_path $2)
|
||||
python ${PROJECT_DIR}/../eval.py $ck_path --scenario=Schrodinger --datapath=$data_set_path > eval.log 2>&1 &
|
|
@ -0,0 +1,44 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
if [ $# != 1 ] && [ $# != 2 ]
|
||||
then
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash scripts/run_standalone_Schrodinger_train.sh [DATASET] [DEVICE_ID](option, default is 0)"
|
||||
echo "for example: bash scripts/run_standalone_Schrodinger_train.sh Data/NLS.mat 0"
|
||||
echo "=============================================================================================================="
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||
data_set_path=$(get_real_path $1)
|
||||
|
||||
export DEVICE_ID=0
|
||||
if [ $# == 2 ];
|
||||
then
|
||||
export DEVICE_ID=$2
|
||||
fi
|
||||
|
||||
python ${PROJECT_DIR}/../train.py --datapath $data_set_path --scenario Schrodinger > train.log 2>&1 &
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Create dataset for training or evaluation"""
|
||||
import mindspore.dataset as ds
|
||||
import numpy as np
|
||||
import scipy.io as scio
|
||||
from pyDOE import lhs
|
||||
|
||||
|
||||
class PINNs_training_set:
|
||||
"""
|
||||
Training set for PINNs (Schrodinger)
|
||||
|
||||
Args:
|
||||
N0 (int): number of sampled training data points for the initial condition
|
||||
Nb (int): number of sampled training data points for the boundary condition
|
||||
Nf (int): number of sampled training data points for the collocation points
|
||||
lb (np.array): lower bound (x, t) of domain
|
||||
ub (np.array): upper bound (x, t) of domain
|
||||
"""
|
||||
def __init__(self, N0, Nb, Nf, lb, ub, path='./Data/NLS.mat'):
|
||||
data = scio.loadmat(path)
|
||||
self.N0 = N0
|
||||
self.Nb = Nb
|
||||
self.Nf = Nf
|
||||
self.lb = lb
|
||||
self.ub = ub
|
||||
|
||||
# load data
|
||||
t = data['tt'].flatten()[:, None]
|
||||
x = data['x'].flatten()[:, None]
|
||||
Exact = data['uu']
|
||||
Exact_u = np.real(Exact)
|
||||
Exact_v = np.imag(Exact)
|
||||
|
||||
idx_x = np.random.choice(x.shape[0], self.N0, replace=False)
|
||||
self.x0 = x[idx_x, :]
|
||||
self.u0 = Exact_u[idx_x, 0:1]
|
||||
self.v0 = Exact_v[idx_x, 0:1]
|
||||
|
||||
idx_t = np.random.choice(t.shape[0], self.Nb, replace=False)
|
||||
self.tb = t[idx_t, :]
|
||||
|
||||
self.X_f = self.lb + (self.ub-self.lb)*lhs(2, self.Nf)
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
if index < self.N0: # N0 initial points
|
||||
x = np.array([self.x0[index][0]], np.float32)
|
||||
t = np.array([0], np.float32)
|
||||
u_target = np.array(self.u0[index], np.float32)
|
||||
v_target = np.array(self.v0[index], np.float32)
|
||||
|
||||
elif self.N0 <= index < self.N0+self.Nb: # Nb lower bound points
|
||||
ind = index - self.N0
|
||||
x = np.array([self.lb[0]], np.float32)
|
||||
t = np.array([self.tb[ind][0]], np.float32)
|
||||
u_target = np.array([self.ub[0]], np.float32)
|
||||
v_target = t
|
||||
|
||||
elif self.N0+self.Nb <= index < self.N0+2*self.Nb: # Nb upper bound points
|
||||
ind = index - self.N0 - self.Nb
|
||||
x = np.array([self.ub[0]], np.float32)
|
||||
t = np.array([self.tb[ind][0]], np.float32)
|
||||
u_target = np.array([self.lb[0]], np.float32)
|
||||
v_target = t
|
||||
|
||||
else: # Nf collocation points
|
||||
ind = index - self.N0 - 2*self.Nb
|
||||
x = np.array(self.X_f[ind, 0:1], np.float32)
|
||||
t = np.array(self.X_f[ind, 1:2], np.float32)
|
||||
u_target = np.array([0], np.float32)
|
||||
v_target = np.array([0], np.float32)
|
||||
|
||||
return np.hstack((x, t)), np.hstack((u_target, v_target))
|
||||
|
||||
def __len__(self):
|
||||
return self.N0+2*self.Nb+self.Nf
|
||||
|
||||
|
||||
def generate_PINNs_training_set(N0, Nb, Nf, lb, ub, path='./Data/NLS.mat'):
|
||||
"""
|
||||
Generate training set for PINNs
|
||||
|
||||
Args: see class PINNs_train_set
|
||||
"""
|
||||
s = PINNs_training_set(N0, Nb, Nf, lb, ub, path)
|
||||
dataset = ds.GeneratorDataset(source=s, column_names=['data', 'label'], shuffle=False)
|
||||
dataset = dataset.batch(batch_size=len(s))
|
||||
return dataset
|
||||
|
||||
|
||||
def get_eval_data(path):
|
||||
"""
|
||||
Get the evaluation data for Schrodinger equation.
|
||||
"""
|
||||
data = scio.loadmat(path)
|
||||
t = data['tt'].astype(np.float32).flatten()[:, None]
|
||||
x = data['x'].astype(np.float32).flatten()[:, None]
|
||||
Exact = data['uu']
|
||||
Exact_u = np.real(Exact).astype(np.float32)
|
||||
Exact_v = np.imag(Exact).astype(np.float32)
|
||||
Exact_h = np.sqrt(Exact_u**2 + Exact_v**2)
|
||||
|
||||
X, T = np.meshgrid(x, t)
|
||||
|
||||
X_star = np.hstack((X.flatten()[:, None], T.flatten()[:, None]))
|
||||
u_star = Exact_u.T.flatten()[:, None]
|
||||
v_star = Exact_v.T.flatten()[:, None]
|
||||
h_star = Exact_h.T.flatten()[:, None]
|
||||
|
||||
return X_star, u_star, v_star, h_star
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Loss function for PINNs (Schrodinger)"""
|
||||
from mindspore import nn, ops
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
|
||||
class PINNs_loss(nn.Cell):
|
||||
"""
|
||||
Loss of the PINNs network, only works with full-batch training. Training data are organized in
|
||||
the following order: initial condition points ([0:n0]), boundary condition points ([n0:(n0+2*nb)]),
|
||||
collocation points ([(n0+2*nb)::])
|
||||
"""
|
||||
def __init__(self, n0, nb, nf, reduction='sum'):
|
||||
super(PINNs_loss, self).__init__(reduction)
|
||||
self.n0 = n0
|
||||
self.nb = nb
|
||||
self.nf = nf
|
||||
self.zeros = ops.Zeros()
|
||||
self.mse = nn.MSELoss(reduction='mean')
|
||||
self.f_target = self.zeros((self.nf, 1), mstype.float32)
|
||||
|
||||
def construct(self, pred, target):
|
||||
"""
|
||||
pred: prediction value (u, v, ux, vx, fu, fv)
|
||||
target: target[:, 0:1] = u_target, target[:, 0:2] = v_target
|
||||
"""
|
||||
u0_pred = pred[0][0:self.n0, 0:1]
|
||||
u0 = target[0:self.n0, 0:1]
|
||||
v0_pred = pred[1][0:self.n0, 0:1]
|
||||
v0 = target[0:self.n0, 1:2]
|
||||
|
||||
u_lb_pred = pred[0][self.n0:(self.n0+self.nb), 0:1]
|
||||
u_ub_pred = pred[0][(self.n0+self.nb):(self.n0+2*self.nb), 0:1]
|
||||
v_lb_pred = pred[1][self.n0:(self.n0+self.nb), 0:1]
|
||||
v_ub_pred = pred[1][(self.n0+self.nb):(self.n0+2*self.nb), 0:1]
|
||||
|
||||
ux_lb_pred = pred[2][self.n0:(self.n0+self.nb), 0:1]
|
||||
ux_ub_pred = pred[2][(self.n0+self.nb):(self.n0+2*self.nb), 0:1]
|
||||
vx_lb_pred = pred[3][self.n0:(self.n0+self.nb), 0:1]
|
||||
vx_ub_pred = pred[3][(self.n0+self.nb):(self.n0+2*self.nb), 0:1]
|
||||
|
||||
fu_pred = pred[4][(self.n0+2*self.nb)::, 0:1]
|
||||
fv_pred = pred[5][(self.n0+2*self.nb)::, 0:1]
|
||||
|
||||
mse_u_0 = self.mse(u0_pred, u0)
|
||||
mse_v_0 = self.mse(v0_pred, v0)
|
||||
mse_u_b = self.mse(u_lb_pred, u_ub_pred)
|
||||
mse_v_b = self.mse(v_lb_pred, v_ub_pred)
|
||||
mse_ux_b = self.mse(ux_lb_pred, ux_ub_pred)
|
||||
mse_vx_b = self.mse(vx_lb_pred, vx_ub_pred)
|
||||
mse_fu = self.mse(fu_pred, self.f_target)
|
||||
mse_fv = self.mse(fv_pred, self.f_target)
|
||||
|
||||
ans = mse_u_0 + mse_v_0 + mse_u_b + mse_v_b + mse_ux_b + mse_vx_b + mse_fu + mse_fv
|
||||
|
||||
return ans
|
|
@ -0,0 +1,166 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Define the PINNs network for the Schrodinger equation."""
|
||||
import numpy as np
|
||||
from mindspore import Parameter, Tensor, nn, ops
|
||||
from mindspore.common.initializer import TruncatedNormal, Zero, initializer
|
||||
from mindspore.ops import constexpr
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
|
||||
@constexpr
|
||||
def _generate_ones(batch_size):
|
||||
arr = np.ones((batch_size, 1), np.float32)
|
||||
return Tensor(arr, mstype.float32)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _generate_zeros(batch_size):
|
||||
arr = np.zeros((batch_size, 1), np.float32)
|
||||
return Tensor(arr, mstype.float32)
|
||||
|
||||
|
||||
class neural_net(nn.Cell):
|
||||
"""
|
||||
Neural net to fit the wave function
|
||||
|
||||
Args:
|
||||
layers (int): num of neurons for each layer
|
||||
lb (np.array): lower bound (x, t) of domain
|
||||
ub (np.array): upper bound (x, t) of domain
|
||||
"""
|
||||
def __init__(self, layers, lb, ub):
|
||||
super(neural_net, self).__init__()
|
||||
self.layers = layers
|
||||
self.concat = ops.Concat(axis=1)
|
||||
self.lb = Tensor(lb, mstype.float32)
|
||||
self.ub = Tensor(ub, mstype.float32)
|
||||
|
||||
self.tanh = ops.Tanh()
|
||||
self.add = ops.Add()
|
||||
self.matmul = ops.MatMul()
|
||||
|
||||
self.w0 = self._init_weight_xavier(0)
|
||||
self.b0 = self._init_biase(0)
|
||||
self.w1 = self._init_weight_xavier(1)
|
||||
self.b1 = self._init_biase(1)
|
||||
self.w2 = self._init_weight_xavier(2)
|
||||
self.b2 = self._init_biase(2)
|
||||
self.w3 = self._init_weight_xavier(3)
|
||||
self.b3 = self._init_biase(3)
|
||||
self.w4 = self._init_weight_xavier(4)
|
||||
self.b4 = self._init_biase(4)
|
||||
|
||||
def construct(self, x, t):
|
||||
"""forward propagation"""
|
||||
X = self.concat((x, t))
|
||||
X = 2.0*(X - self.lb)/(self.ub - self.lb) - 1.0
|
||||
|
||||
X = self.tanh(self.add(self.matmul(X, self.w0), self.b0))
|
||||
X = self.tanh(self.add(self.matmul(X, self.w1), self.b1))
|
||||
X = self.tanh(self.add(self.matmul(X, self.w2), self.b2))
|
||||
X = self.tanh(self.add(self.matmul(X, self.w3), self.b3))
|
||||
X = self.add(self.matmul(X, self.w4), self.b4)
|
||||
|
||||
return X[:, 0:1], X[:, 1:2]
|
||||
|
||||
def _init_weight_xavier(self, layer):
|
||||
"""
|
||||
Initialize weight for the ith layer
|
||||
"""
|
||||
in_dim = self.layers[layer]
|
||||
out_dim = self.layers[layer+1]
|
||||
std = np.sqrt(2/(in_dim + out_dim))
|
||||
name = 'w' + str(layer)
|
||||
return Parameter(default_input=initializer(TruncatedNormal(std), [in_dim, out_dim], mstype.float32),
|
||||
name=name, requires_grad=True)
|
||||
|
||||
def _init_biase(self, layer):
|
||||
"""
|
||||
Initialize biase for the ith layer
|
||||
"""
|
||||
name = 'b' + str(layer)
|
||||
return Parameter(default_input=initializer(Zero(), self.layers[layer+1], mstype.float32),
|
||||
name=name, requires_grad=True)
|
||||
|
||||
|
||||
class Grad_1(nn.Cell):
|
||||
"""
|
||||
Using the first output to compute gradient.
|
||||
"""
|
||||
def __init__(self, net):
|
||||
super(Grad_1, self).__init__()
|
||||
self.net = net
|
||||
self.grad = ops.GradOperation(get_all=True, sens_param=True)
|
||||
|
||||
def construct(self, x, t):
|
||||
sens_1 = _generate_ones(x.shape[0])
|
||||
sens_2 = _generate_zeros(x.shape[0])
|
||||
return self.grad(self.net)(x, t, (sens_1, sens_2))
|
||||
|
||||
|
||||
class Grad_2(nn.Cell):
|
||||
"""
|
||||
Using the second output to compute gradient.
|
||||
"""
|
||||
def __init__(self, net):
|
||||
super(Grad_2, self).__init__()
|
||||
self.net = net
|
||||
self.grad = ops.GradOperation(get_all=True, sens_param=True)
|
||||
|
||||
def construct(self, x, t):
|
||||
sens_1 = _generate_zeros(x.shape[0])
|
||||
sens_2 = _generate_ones(x.shape[0])
|
||||
return self.grad(self.net)(x, t, (sens_1, sens_2))
|
||||
|
||||
|
||||
class PINNs(nn.Cell):
|
||||
"""
|
||||
PINNs for the Schrodinger equation.
|
||||
"""
|
||||
def __init__(self, layers, lb, ub):
|
||||
super(PINNs, self).__init__()
|
||||
self.nn = neural_net(layers, lb, ub)
|
||||
self.du = Grad_1(self.nn)
|
||||
self.dv = Grad_2(self.nn)
|
||||
self.dux = Grad_1(self.du)
|
||||
self.dvx = Grad_1(self.dv)
|
||||
|
||||
self.add = ops.Add()
|
||||
self.pow = ops.Pow()
|
||||
self.mul = ops.Mul()
|
||||
|
||||
def construct(self, X):
|
||||
"""forward propagation"""
|
||||
x = X[:, 0:1]
|
||||
t = X[:, 1:2]
|
||||
u, v = self.nn(x, t)
|
||||
ux, ut = self.du(x, t)
|
||||
vx, vt = self.dv(x, t)
|
||||
uxx, _ = self.dux(x, t)
|
||||
vxx, _ = self.dvx(x, t)
|
||||
|
||||
square_sum = self.add(self.pow(u, 2), self.pow(v, 2))
|
||||
|
||||
fu1 = self.mul(vxx, 0.5)
|
||||
fu2 = self.mul(square_sum, v)
|
||||
fu = self.add(self.add(ut, fu1), fu2)
|
||||
|
||||
fv1 = self.mul(uxx, -0.5)
|
||||
fv2 = self.mul(square_sum, u)
|
||||
fv2 = self.mul(fv2, -1.0)
|
||||
fv = self.add(self.add(vt, fv1), fv2)
|
||||
|
||||
return u, v, ux, vx, fu, fv
|
|
@ -0,0 +1,14 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Network config setting
|
||||
"""
|
||||
|
||||
|
||||
# config for Schrodinger equation scenario
|
||||
config_Sch = {'epoch': 50000, 'lr': 0.0001, 'N0': 50, 'Nb': 50, 'Nf': 20000, 'num_neuron': 100,
|
||||
'seed': 2, 'path': './Data/NLS.mat', 'ck_path': './ckpoints/'}
|
|
@ -0,0 +1,95 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train PINNs"""
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
from mindspore import Model, context, nn
|
||||
from mindspore.train.callback import (CheckpointConfig, LossMonitor,
|
||||
ModelCheckpoint, TimeMonitor)
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src import config
|
||||
from src.Schrodinger.dataset import generate_PINNs_training_set
|
||||
from src.Schrodinger.loss import PINNs_loss
|
||||
from src.Schrodinger.net import PINNs
|
||||
|
||||
def train_sch(epoch=50000, lr=0.0001, N0=50, Nb=50, Nf=20000, num_neuron=100, seed=None,
|
||||
path='./Data/NLS.mat', ck_path='./ckpoints/'):
|
||||
"""
|
||||
Train PINNs network for Schrodinger equation
|
||||
|
||||
Args:
|
||||
epoch (int): number of epochs
|
||||
lr (float): learning rate
|
||||
N0 (int): number of data points sampled from the initial condition,
|
||||
0<N0<=256 for the default NLS dataset
|
||||
Nb (int): number of data points sampled from the boundary condition,
|
||||
0<Nb<=201 for the default NLS dataset. Size of training set = N0+2*Nb
|
||||
Nf (int): number of collocation points, collocation points are used
|
||||
to calculate regularizer for the network from Schoringer equation.
|
||||
0<Nf<=51456 for the default NLS dataset
|
||||
num_neuron (int): number of neurons for fully connected layer in the network
|
||||
seed (int): random seed
|
||||
path (str): path of the dataset for Schrodinger equation
|
||||
ck_path (str): path to store checkpoint files (.ckpt)
|
||||
"""
|
||||
if seed is not None:
|
||||
np.random.seed(seed)
|
||||
set_seed(seed)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
|
||||
layers = [2, num_neuron, num_neuron, num_neuron, num_neuron, 2]
|
||||
|
||||
lb = np.array([-5.0, 0.0])
|
||||
ub = np.array([5.0, np.pi/2])
|
||||
|
||||
training_set = generate_PINNs_training_set(N0, Nb, Nf, lb, ub, path=path)
|
||||
|
||||
n = PINNs(layers, lb, ub)
|
||||
opt = nn.Adam(n.trainable_params(), learning_rate=lr)
|
||||
loss = PINNs_loss(N0, Nb, Nf)
|
||||
|
||||
#call back configuration
|
||||
loss_print_num = 1 # print loss per loss_print_num epochs
|
||||
# save model
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=1000, keep_checkpoint_max=50)
|
||||
ckpoint = ModelCheckpoint(prefix="checkpoint_PINNs_Schrodinger", directory=ck_path, config=config_ck)
|
||||
|
||||
model = Model(network=n, loss_fn=loss, optimizer=opt)
|
||||
|
||||
model.train(epoch=epoch, train_dataset=training_set,
|
||||
callbacks=[LossMonitor(loss_print_num), ckpoint, TimeMonitor(1)], dataset_sink_mode=True)
|
||||
print('Training complete')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Train PINNs')
|
||||
|
||||
#only support 'Schrodinger' for now
|
||||
parser.add_argument('--scenario', type=str, help='scenario for PINNs', default='Schrodinger')
|
||||
parser.add_argument('--datapath', type=str, help='path for dataset', default='')
|
||||
|
||||
args_opt = parser.parse_args()
|
||||
pinns_scenario = args_opt.scenario
|
||||
data_path = args_opt.datapath
|
||||
if pinns_scenario == 'Schrodinger':
|
||||
conf = config.config_Sch
|
||||
if data_path != '':
|
||||
conf['path'] = data_path
|
||||
train_sch(**conf)
|
||||
else:
|
||||
print(f'{pinns_scenario} is not supported in PINNs training for now')
|
|
@ -22,4 +22,4 @@ subword-nmt>=0.3.7 # for st test
|
|||
sacrebleu>=1.4.14 # for st test
|
||||
sacremoses>=0.0.35 # for st test
|
||||
absl-py>=0.10.0 # for st test
|
||||
six>=1.15.0 # for st test
|
||||
six>=1.15.0 # for st test
|
Loading…
Reference in New Issue