forked from mindspore-Ecosystem/mindspore
!9382 add ocean model GOMO
From: @wangmin0104 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
7a5c39ea1d
|
@ -1,6 +1,5 @@
|
|||
![](https://www.mindspore.cn/static/img/logo.a3e472c9.png)
|
||||
|
||||
|
||||
# Welcome to the Model Zoo for MindSpore
|
||||
|
||||
In order to facilitate developers to enjoy the benefits of MindSpore framework, we will continue to add typical networks and some of the related pre-trained models. If you have needs for the model zoo, you can file an issue on [gitee](https://gitee.com/mindspore/mindspore/issues) or [MindSpore](https://bbs.huaweicloud.com/forum/forum-1076-1.html), We will consider it in time.
|
||||
|
@ -11,8 +10,6 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
|
|||
|
||||
- Officially maintained and supported
|
||||
|
||||
|
||||
|
||||
# Table of Contents
|
||||
|
||||
- [Official](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official)
|
||||
|
@ -30,13 +27,13 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
|
|||
- [MobileNetV2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/mobilenetv2/README.md)
|
||||
- [MobileNetV2_Quant](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/mobilenetv2_quant/README.md)
|
||||
- [MobileNetV3](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/mobilenetv3/README.md)
|
||||
- [InceptionV3](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/inceptionv3/README.md)
|
||||
- [InceptionV3](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/inceptionv3/README.md)
|
||||
- [Object Detection and Segmentation](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv)
|
||||
- [DeepLabV3](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/deeplabv3/README.md)
|
||||
- [FasterRCNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/faster_rcnn/README.md)
|
||||
- [YoloV3-DarkNet53](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/yolov3_darknet53/README.md)
|
||||
- [YoloV3-DarkNet53](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/yolov3_darknet53/README.md)
|
||||
- [YoloV3-ResNet18](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/yolov3_resnet18/README.md)
|
||||
- [MaskRCNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/maskrcnn/README.md)
|
||||
- [MaskRCNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/maskrcnn/README.md)
|
||||
- [SSD](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ssd/README.md)
|
||||
- [Warp-CTC](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/warpctc/README.md)
|
||||
|
||||
|
@ -44,53 +41,49 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
|
|||
- [BERT[benchmark]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/bert/README.md)
|
||||
- [TinyBERT](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/tinybert/README.md)
|
||||
- [GNMT V2](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/gnmt_v2/README.md)
|
||||
- [LSTM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/lstm/README.md)
|
||||
- [LSTM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/lstm/README.md)
|
||||
- [MASS](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/mass/README.md)
|
||||
- [Transformer](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/transformer/README.md)
|
||||
- [Recommender Systems](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend)
|
||||
- [DeepFM](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/deepfm/README.md)
|
||||
- [Wide&Deep[benchmark]](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/recommend/wide_and_deep/README.md)
|
||||
- [Graph Neural Networks](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn)
|
||||
- [BGCF](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf/README.md)
|
||||
- [BGCF](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf/README.md)
|
||||
- [GAT](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gat/README.md)
|
||||
- [GCN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gcn//README.md)
|
||||
|
||||
- [Research](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research)
|
||||
- [Computer Vision](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv)
|
||||
- [Research](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research)
|
||||
- [Computer Vision](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv)
|
||||
- [GhostNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ghostnet/README.md)
|
||||
- [GhostNet_Quant](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ghostnet_quant/README.md)
|
||||
- [ResNet50-0.65x](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/resnet50_adv_pruning/README.md)
|
||||
- [SSD_GhostNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ssd_ghostnet/README.md)
|
||||
- [GhostNet_Quant](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ghostnet_quant/README.md)
|
||||
- [ResNet50-0.65x](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/resnet50_adv_pruning/README.md)
|
||||
- [SSD_GhostNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ssd_ghostnet/README.md)
|
||||
- [TinyNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/tinynet/README.md)
|
||||
- [Natural Language Processing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp)
|
||||
- [DS-CNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp/dscnn/README.md)
|
||||
- [Natural Language Processing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp)
|
||||
- [DS-CNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp/dscnn/README.md)
|
||||
- [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)
|
||||
- [FCN-4](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/fcn-4/README.md)
|
||||
|
||||
- [Community](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/community)
|
||||
|
||||
|
||||
- [FCN-4](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/fcn-4/README.md)
|
||||
- [High Performance Computing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc)
|
||||
- [GOMO](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ocean_model/README.md)
|
||||
|
||||
- [Community](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/community)
|
||||
|
||||
# Announcements
|
||||
|
||||
| Date | News |
|
||||
| ------------ | ------------------------------------------------------------ |
|
||||
| September 25, 2020 | Support [MindSpore v1.0.0](https://www.mindspore.cn/news/newschildren/en?id=262) |
|
||||
| September 01, 2020 | Support [MindSpore v0.7.0-beta](https://www.mindspore.cn/news/newschildren/en?id=246) |
|
||||
| July 31, 2020 | Support [MindSpore v0.6.0-beta](https://www.mindspore.cn/news/newschildren/en?id=237) |
|
||||
|
||||
|
||||
|
||||
# Disclaimers
|
||||
|
||||
Mindspore only provides scripts that downloads and preprocesses public datasets. We do not own these datasets and are not responsible for their quality or maintenance. Please make sure you have permission to use the dataset under the dataset’s license. The models trained on these dataset are for non-commercial research and educational purpose only.
|
||||
Mindspore only provides scripts that downloads and preprocesses public datasets. We do not own these datasets and are not responsible for their quality or maintenance. Please make sure you have permission to use the dataset under the dataset’s license. The models trained on these dataset are for non-commercial research and educational purpose only.
|
||||
|
||||
To dataset owners: we will remove or update all public content upon request if you don’t want your dataset included on Mindspore, or wish to update it in any way. Please contact us through a Github/Gitee issue. Your understanding and contribution to this community is greatly appreciated.
|
||||
|
||||
MindSpore is Apache 2.0 licensed. Please see the LICENSE file.
|
||||
|
||||
|
||||
|
||||
# License
|
||||
|
||||
[Apache License 2.0](https://gitee.com/mindspore/mindspore/blob/master/LICENSE)
|
||||
|
|
|
@ -0,0 +1,102 @@
|
|||
# GOMO Example
|
||||
|
||||
- [Description](#Description)
|
||||
- [Model Architecture](#Model-Architecture)
|
||||
- [Dataset](#Dataset)
|
||||
- [Environment Requirements](#Environment-Requirements)
|
||||
- [Quick Start](#Quick-Start)
|
||||
- [Script Description](#Script-Description)
|
||||
- [Script and Sample Code](#Script-and-Sample-Code)
|
||||
- [Training Process](#Training-Process)
|
||||
- [Model Description](#Model-Description)
|
||||
- [Evaluation Performance](#Evaluation-Performance)
|
||||
- [Description of Random Situation](#Description-of-Random-Situation)
|
||||
- [ModelZoo Homepage](#ModelZoo-Homepage)
|
||||
|
||||
## Description
|
||||
|
||||
Generalized Operator Modelling of the Ocean (GOMO) is a three-dimensional ocean model based on OpenArray which is a simple operator library for the decoupling of ocean modelling and parallel computing (Xiaomeng Huang et al, 2019). GOMO is a numerical solution model using finite differential algorithm to solve PDE equations. With MindSpore and GPU, we can achieve great improvments in solving those PDE equations compared with CPU.
|
||||
This is an example of training GOMO Model with MindSpore on GPU.
|
||||
|
||||
## Model Architecture
|
||||
|
||||
The overall model architecture of GOMO is show below:[link](https://gmd.copernicus.org/articles/12/4729/2019/gmd-12-4729-2019-discussion.html). The fundamental equations and algorithms of GOMO can also be found in this article
|
||||
|
||||
## Dataset
|
||||
|
||||
Dataset used: Seamount
|
||||
|
||||
- Dataset size: 65x49x21
|
||||
|
||||
- Data format:nc
|
||||
|
||||
- Download the dataset
|
||||
|
||||
> download the GOMO from Github and you can find the seamount dataset file in the `GOMO/bin/data` directory.
|
||||
|
||||
## Environment Requirements
|
||||
|
||||
- Hardware: GPU
|
||||
- Prepare hardware environment with GPU processor.
|
||||
- Framework
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
## Quick Start
|
||||
|
||||
After installing MindSpore via the official website, you can start training as follows:
|
||||
|
||||
```shell
|
||||
# run distributed training example
|
||||
sh run_distribute_train.sh [im] [jm] [kb] [step] [DATASET_PATH]
|
||||
```
|
||||
|
||||
## Script Description
|
||||
|
||||
### Script and Sample Code
|
||||
|
||||
```shell
|
||||
└── ocean_model
|
||||
├── README.md # descriptions about ocean model GOMO
|
||||
├── scripts
|
||||
│ ├── run_distribute_train.sh # launch distributed training for GPU
|
||||
├──src
|
||||
│ ├── GOMO.py # GOMO model
|
||||
│ ├── Grid.py # grid initial
|
||||
│ ├── stencil.py # averaging and differential stencil oprator
|
||||
│ ├── op_operator.py # averaging and differential kernel operator
|
||||
│ ├── read_var.py # read variables from nc file
|
||||
├── train.py # train script
|
||||
```
|
||||
|
||||
### Training Process
|
||||
|
||||
```shell
|
||||
sh run_distribute_train.sh [im] [jm] [kb] [step] [DATASET_PATH]
|
||||
```
|
||||
|
||||
Training result will be stored in the current path, whose folder name begins with "train".
|
||||
|
||||
## Model Description
|
||||
|
||||
### Evaluation Performance
|
||||
|
||||
| Parameters | GPU |
|
||||
| -------------------------- |---------------------------------- |
|
||||
| Resource | GPU(Tesla V100 SXM2),Memory 16G
|
||||
| uploaded Date |
|
||||
| MindSpore Version |
|
||||
| Dataset | Seamount
|
||||
| Training Parameters | step=10, im=65, km=49, kb=21
|
||||
| Outputs | numpy file
|
||||
| Speed | 17 ms/step
|
||||
| Total time | 3 mins
|
||||
| Scripts | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc/ocean_model)
|
||||
|
||||
## Description of Random Situation
|
||||
|
||||
## ModelZoo HomePage
|
||||
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,42 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 5 ]
|
||||
then
|
||||
echo "Usage: sh run_standalone_train_gpu.sh [im] [jm] [kb] [step] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
mkdir ./outputs
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --im=$1 --jm=$2 --kb=$3 --step=$4 --file_path=$5 --outputs_path="./outputs/" &> log &
|
||||
|
||||
cd ..
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,57 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Grid"""
|
||||
|
||||
import mindspore.nn as nn
|
||||
from src.stencil import AXB, AYB, AZB
|
||||
|
||||
|
||||
class Grid(nn.Cell):
|
||||
"""
|
||||
init C grid
|
||||
"""
|
||||
def __init__(self, im, jm, km, stencil_width=1):
|
||||
super(Grid, self).__init__()
|
||||
self.im = im
|
||||
self.jm = jm
|
||||
self.km = km
|
||||
self.x_map = [1, 0, 3, 2, 5, 4, 7, 6]
|
||||
self.y_map = [2, 3, 0, 1, 6, 7, 4, 5]
|
||||
self.z_map = [4, 5, 6, 7, 0, 1, 2, 3]
|
||||
self.AXB = AXB(stencil_width=stencil_width)
|
||||
self.AYB = AYB(stencil_width=stencil_width)
|
||||
self.AZB = AZB(stencil_width=stencil_width)
|
||||
|
||||
def construct(self, dx, dy, dz):
|
||||
"""construct"""
|
||||
dx0 = self.AYB(self.AXB(dx))
|
||||
dy0 = self.AYB(self.AXB(dy))
|
||||
dz0 = dz
|
||||
|
||||
dx1 = self.AYB(dx)
|
||||
dy1 = self.AYB(dy)
|
||||
dz1 = self.AZB(dz)
|
||||
|
||||
dx2 = self.AXB(dx)
|
||||
dy2 = self.AXB(dy)
|
||||
|
||||
dx3 = dx
|
||||
dy3 = dy
|
||||
|
||||
x_d = (dx0, dx1, dx2, dx3)
|
||||
y_d = (dy0, dy1, dy2, dy3)
|
||||
z_d = (dz0, dz1)
|
||||
|
||||
return x_d, y_d, z_d
|
|
@ -0,0 +1,210 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""stencil operations kernel"""
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class axb_kernel(nn.Cell):
|
||||
"""create axb_kernel"""
|
||||
def __init__(self):
|
||||
super(axb_kernel, self).__init__()
|
||||
self.pad = P.Pad(((1, 0), (0, 0), (0, 0)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.slice(x1, (0, 0, 0), x_shape)
|
||||
out = 0.5 * (x + x1)
|
||||
return out
|
||||
|
||||
|
||||
class ayb_kernel(nn.Cell):
|
||||
"""create ayb_kernel"""
|
||||
def __init__(self):
|
||||
super(ayb_kernel, self).__init__()
|
||||
self.pad = P.Pad(((0, 0), (1, 0), (0, 0)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.slice(x1, (0, 0, 0), x_shape)
|
||||
out = 0.5 * (x + x1)
|
||||
return out
|
||||
|
||||
|
||||
class azb_kernel(nn.Cell):
|
||||
"""create azb_kernel"""
|
||||
def __init__(self):
|
||||
super(azb_kernel, self).__init__()
|
||||
self.pad = P.Pad(((0, 0), (0, 0), (1, 0)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.slice(x1, (0, 0, 0), x_shape)
|
||||
out = 0.5 * (x + x1)
|
||||
return out
|
||||
|
||||
|
||||
class axf_kernel(nn.Cell):
|
||||
"""create axf_kernel"""
|
||||
def __init__(self):
|
||||
super(axf_kernel, self).__init__()
|
||||
self.pad = P.Pad(((0, 1), (0, 0), (0, 0)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.slice(x1, (1, 0, 0), x_shape)
|
||||
out = 0.5 * (x + x1)
|
||||
return out
|
||||
|
||||
|
||||
class ayf_kernel(nn.Cell):
|
||||
"""create ayf_kernel"""
|
||||
def __init__(self):
|
||||
super(ayf_kernel, self).__init__()
|
||||
self.pad = P.Pad(((0, 0), (0, 1), (0, 0)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.slice(x1, (0, 1, 0), x_shape)
|
||||
out = 0.5 * (x + x1)
|
||||
return out
|
||||
|
||||
|
||||
class azf_kernel(nn.Cell):
|
||||
"""create azf_kernel"""
|
||||
def __init__(self):
|
||||
super(azf_kernel, self).__init__()
|
||||
self.pad = P.Pad(((0, 0), (0, 0), (0, 1)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.slice(x1, (0, 0, 1), x_shape)
|
||||
out = 0.5 * (x + x1)
|
||||
return out
|
||||
|
||||
|
||||
class dxb_kernel(nn.Cell):
|
||||
"""create dxb_kernel"""
|
||||
def __init__(self):
|
||||
super(dxb_kernel, self).__init__()
|
||||
self.pad = P.Pad(((1, 0), (0, 0), (0, 0)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.slice(x1, (0, 0, 0), x_shape)
|
||||
x = x - x1
|
||||
return x
|
||||
|
||||
|
||||
class dxf_kernel(nn.Cell):
|
||||
"""create dxf_kernel"""
|
||||
def __init__(self):
|
||||
super(dxf_kernel, self).__init__()
|
||||
self.pad = P.Pad(((0, 1), (0, 0), (0, 0)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.slice(x1, (1, 0, 0), x_shape)
|
||||
x = x1 - x
|
||||
return x
|
||||
|
||||
|
||||
class dyb_kernel(nn.Cell):
|
||||
"""create dyb_kernel"""
|
||||
def __init__(self):
|
||||
super(dyb_kernel, self).__init__()
|
||||
self.pad = P.Pad(((0, 0), (1, 0), (0, 0)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.slice(x1, (0, 0, 0), x_shape)
|
||||
x = x - x1
|
||||
return x
|
||||
|
||||
|
||||
class dyf_kernel(nn.Cell):
|
||||
"""create dyf_kernel"""
|
||||
def __init__(self):
|
||||
super(dyf_kernel, self).__init__()
|
||||
self.pad = P.Pad(((0, 0), (0, 1), (0, 0)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.slice(x1, (0, 1, 0), x_shape)
|
||||
x = x1 - x
|
||||
return x
|
||||
|
||||
|
||||
class dzb_kernel(nn.Cell):
|
||||
"""create dzb_kernel"""
|
||||
def __init__(self):
|
||||
super(dzb_kernel, self).__init__()
|
||||
self.pad = P.Pad(((0, 0), (0, 0), (1, 0)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.slice(x1, (0, 0, 0), x_shape)
|
||||
x = x - x1
|
||||
return x
|
||||
|
||||
|
||||
class dzf_kernel(nn.Cell):
|
||||
"""create dzf_kernel"""
|
||||
def __init__(self):
|
||||
super(dzf_kernel, self).__init__()
|
||||
self.pad = P.Pad(((0, 0), (0, 0), (0, 1)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.slice(x1, (0, 0, 1), x_shape)
|
||||
x = x1 - x
|
||||
return x
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""read variables"""
|
||||
|
||||
import numpy as np
|
||||
import netCDF4 as nc
|
||||
|
||||
|
||||
# variable name list
|
||||
params_name = ['z', 'zz', 'dz', 'dzz', 'dx', 'dy', 'cor', 'h', 'fsm', 'dum', 'dvm', 'art', 'aru', 'arv', 'rfe', 'rfw',
|
||||
'rfn', 'rfs', 'east_e', 'north_e', 'east_c', 'north_c', 'east_u', 'north_u', 'east_v', 'north_v', 'tb',
|
||||
'sb', 'tclim', 'sclim', 'rot', 'vfluxf', 'wusurf', 'wvsurf', 'e_atmos', 'ub', 'vb', 'uab', 'vab', 'elb',
|
||||
'etb', 'dt', 'uabw', 'uabe', 'vabs', 'vabn', 'els', 'eln', 'ele', 'elw', 'ssurf', 'tsurf', 'tbe', 'sbe',
|
||||
'sbw', 'tbw', 'tbn', 'tbs', 'sbn', 'sbs', 'wtsurf', 'swrad']
|
||||
|
||||
|
||||
def load_var(file_obj, name):
|
||||
"""load variable from nc data file"""
|
||||
data = file_obj.variables[name]
|
||||
data = data[:]
|
||||
data = np.float32(np.transpose(data, (2, 1, 0)))
|
||||
return data
|
||||
|
||||
|
||||
def read_nc(file_path):
|
||||
""" put the load variable into the dict """
|
||||
variable = {}
|
||||
file_obj = nc.Dataset(file_path)
|
||||
for name in params_name:
|
||||
variable[name] = load_var(file_obj, name)
|
||||
return variable
|
|
@ -0,0 +1,392 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""stencil operations"""
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from src.oa_operator import axb_kernel, axf_kernel, ayb_kernel, ayf_kernel, azb_kernel, azf_kernel
|
||||
from src.oa_operator import dxb_kernel, dxf_kernel, dyb_kernel, dyf_kernel, dzb_kernel, dzf_kernel
|
||||
|
||||
class AXB(nn.Cell):
|
||||
"""
|
||||
backward averaging operation along x direction
|
||||
output = (input[i, j, k] + input[i-1, j, k]) / 2
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - The input should be a 3-dimension tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is the same as the input.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
"""
|
||||
def __init__(self, stencil_width=1, lb=(1, 0, 0), rb=(0, 0, 0)):
|
||||
super(AXB, self).__init__()
|
||||
self.stencil_width = stencil_width
|
||||
self.pad = P.Pad(((self.stencil_width, self.stencil_width), (self.stencil_width, self.stencil_width),
|
||||
(self.stencil_width, self.stencil_width)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
self.axb_kernel = axb_kernel()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.axb_kernel(x1)
|
||||
x1 = self.slice(x1, (self.stencil_width, self.stencil_width, self.stencil_width), x_shape)
|
||||
return x1
|
||||
|
||||
|
||||
class AXF(nn.Cell):
|
||||
"""
|
||||
forward averaging operation along x direction
|
||||
output = (input[i, j, k] + input[i+1, j, k]) / 2
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - The input should be a 3-dimension tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is the same as the input.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
"""
|
||||
def __init__(self, stencil_width=1, lb=(0, 0, 0), rb=(1, 0, 0)):
|
||||
super(AXF, self).__init__()
|
||||
self.stencil_width = stencil_width
|
||||
self.pad = P.Pad(((self.stencil_width, self.stencil_width), (self.stencil_width, self.stencil_width),
|
||||
(self.stencil_width, self.stencil_width)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
self.axf_kernel = axf_kernel()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.axf_kernel(x1)
|
||||
x1 = self.slice(x1, (self.stencil_width, self.stencil_width, self.stencil_width), x_shape)
|
||||
return x1
|
||||
|
||||
|
||||
class AYB(nn.Cell):
|
||||
"""
|
||||
backward averaging operation along y direction
|
||||
output = (input[i, j, k] + input[i, j-1, k]) / 2
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - The input should be a 3-dimension tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is the same as the input.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
"""
|
||||
def __init__(self, stencil_width=1, lb=(0, 1, 0), rb=(0, 0, 0)):
|
||||
super(AYB, self).__init__()
|
||||
self.stencil_width = stencil_width
|
||||
self.pad = P.Pad(((self.stencil_width, self.stencil_width), (self.stencil_width, self.stencil_width),
|
||||
(self.stencil_width, self.stencil_width)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
self.ayb_kernel = ayb_kernel()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.ayb_kernel(x1)
|
||||
x1 = self.slice(x1, (self.stencil_width, self.stencil_width, self.stencil_width), x_shape)
|
||||
return x1
|
||||
|
||||
|
||||
class AYF(nn.Cell):
|
||||
"""
|
||||
forward averaging operation along y direction
|
||||
output = (input[i, j, k] + input[i, j+1, k]) / 2
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - The input should be a 3-dimension tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is the same as the input.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
"""
|
||||
def __init__(self, stencil_width=1, lb=(0, 0, 0), rb=(0, 1, 0)):
|
||||
super(AYF, self).__init__()
|
||||
self.stencil_width = stencil_width
|
||||
self.pad = P.Pad(((self.stencil_width, self.stencil_width), (self.stencil_width, self.stencil_width),
|
||||
(self.stencil_width, self.stencil_width)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
self.ayf_kernel = ayf_kernel()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.ayf_kernel(x1)
|
||||
x1 = self.slice(x1, (self.stencil_width, self.stencil_width, self.stencil_width), x_shape)
|
||||
return x1
|
||||
|
||||
|
||||
class AZB(nn.Cell):
|
||||
"""
|
||||
backward averaging operation along z direction
|
||||
output = (input[i, j, k] + input[i, j, k-1]) / 2
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - The input should be a 3-dimension tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is the same as the input.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
"""
|
||||
def __init__(self, stencil_width=1, lb=(0, 0, 1), rb=(0, 0, 0)):
|
||||
super(AZB, self).__init__()
|
||||
self.stencil_width = stencil_width
|
||||
self.pad = P.Pad(((self.stencil_width, self.stencil_width), (self.stencil_width, self.stencil_width),
|
||||
(self.stencil_width, self.stencil_width)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
self.azb_kernel = azb_kernel()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.azb_kernel(x1)
|
||||
x1 = self.slice(x1, (self.stencil_width, self.stencil_width, self.stencil_width), x_shape)
|
||||
return x1
|
||||
|
||||
|
||||
class AZF(nn.Cell):
|
||||
"""
|
||||
forward averaging operation along z direction
|
||||
output = (input[i, j, k] + input[i, j, k+1]) / 2
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - The input should be a 3-dimension tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is the same as the input.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
"""
|
||||
def __init__(self, stencil_width=1, lb=(0, 0, 0), rb=(0, 0, 1)):
|
||||
super(AZF, self).__init__()
|
||||
self.stencil_width = stencil_width
|
||||
self.pad = P.Pad(((self.stencil_width, self.stencil_width), (self.stencil_width, self.stencil_width),
|
||||
(self.stencil_width, self.stencil_width)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
self.azf_kernel = azf_kernel()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.azf_kernel(x1)
|
||||
x1 = self.slice(x1, (self.stencil_width, self.stencil_width, self.stencil_width), x_shape)
|
||||
return x1
|
||||
|
||||
|
||||
class DXB(nn.Cell):
|
||||
"""
|
||||
backward differential operation along x direction
|
||||
output = input[i, j, k] - input[i-1, j, k]
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - The input should be a 3-dimension tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is the same as the input.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
"""
|
||||
def __init__(self, stencil_width=1, lb=(1, 0, 0), rb=(0, 0, 0)):
|
||||
super(DXB, self).__init__()
|
||||
self.stencil_width = stencil_width
|
||||
self.pad = P.Pad(((self.stencil_width, self.stencil_width), (self.stencil_width, self.stencil_width),
|
||||
(self.stencil_width, self.stencil_width)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
self.dxb_kernel = dxb_kernel()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.dxb_kernel(x1)
|
||||
x1 = self.slice(x1, (self.stencil_width, self.stencil_width, self.stencil_width), x_shape)
|
||||
return x1
|
||||
|
||||
|
||||
class DXF(nn.Cell):
|
||||
"""
|
||||
forward differential operation along x direction
|
||||
output = input[i+1, j, k] - input[i, j, k]
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - The input should be a 3-dimension tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is the same as the input.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
"""
|
||||
def __init__(self, stencil_width=1, lb=(0, 0, 0), rb=(1, 0, 0)):
|
||||
super(DXF, self).__init__()
|
||||
self.stencil_width = stencil_width
|
||||
self.pad = P.Pad(((self.stencil_width, self.stencil_width), (self.stencil_width, self.stencil_width),
|
||||
(self.stencil_width, self.stencil_width)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
self.dxf_kernel = dxf_kernel()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.dxf_kernel(x1)
|
||||
x1 = self.slice(x1, (self.stencil_width, self.stencil_width, self.stencil_width), x_shape)
|
||||
return x1
|
||||
|
||||
|
||||
class DYB(nn.Cell):
|
||||
"""
|
||||
backward differential operation along y direction
|
||||
output = input[i, j, k] - input[i, j-1, k]
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - The input should be a 3-dimension tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is the same as the input.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
"""
|
||||
def __init__(self, stencil_width=1, lb=(0, 1, 0), rb=(0, 0, 0)):
|
||||
super(DYB, self).__init__()
|
||||
self.stencil_width = stencil_width
|
||||
self.pad = P.Pad(((self.stencil_width, self.stencil_width), (self.stencil_width, self.stencil_width),
|
||||
(self.stencil_width, self.stencil_width)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
self.dyb_kernel = dyb_kernel()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.dyb_kernel(x1)
|
||||
x1 = self.slice(x1, (self.stencil_width, self.stencil_width, self.stencil_width), x_shape)
|
||||
return x1
|
||||
|
||||
|
||||
class DYF(nn.Cell):
|
||||
"""
|
||||
forward differential operation along y direction
|
||||
output = input[i, j+1, k] - input[i, j, k]
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - The input should be a 3-dimension tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is the same as the input.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
"""
|
||||
def __init__(self, stencil_width=1, lb=(0, 0, 0), rb=(0, 1, 0)):
|
||||
super(DYF, self).__init__()
|
||||
self.stencil_width = stencil_width
|
||||
self.pad = P.Pad(((self.stencil_width, self.stencil_width), (self.stencil_width, self.stencil_width),
|
||||
(self.stencil_width, self.stencil_width)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
self.dyf_kernel = dyf_kernel()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.dyf_kernel(x1)
|
||||
x1 = self.slice(x1, (self.stencil_width, self.stencil_width, self.stencil_width), x_shape)
|
||||
return x1
|
||||
|
||||
|
||||
class DZB(nn.Cell):
|
||||
"""
|
||||
backward differential operation along z direction
|
||||
output = input[i, j, k] - input[i, j, k-1]
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - The input should be a 3-dimension tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is the same as the input.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
"""
|
||||
def __init__(self, stencil_width=1, lb=(0, 0, 1), rb=(0, 0, 0)):
|
||||
super(DZB, self).__init__()
|
||||
self.stencil_width = stencil_width
|
||||
self.pad = P.Pad(((self.stencil_width, self.stencil_width), (self.stencil_width, self.stencil_width),
|
||||
(self.stencil_width, self.stencil_width)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
self.dzb_kernel = dzb_kernel()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.dzb_kernel(x1)
|
||||
x1 = self.slice(x1, (self.stencil_width, self.stencil_width, self.stencil_width), x_shape)
|
||||
return x1
|
||||
|
||||
|
||||
class DZF(nn.Cell):
|
||||
"""
|
||||
forward differential operation along z direction
|
||||
output = input[i, j, k+1] - input[i, j, k]
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - The input should be a 3-dimension tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape is the same as the input.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
"""
|
||||
|
||||
def __init__(self, stencil_width=1, lb=(0, 0, 0), rb=(0, 0, 1)):
|
||||
super(DZF, self).__init__()
|
||||
self.stencil_width = stencil_width
|
||||
self.pad = P.Pad(((self.stencil_width, self.stencil_width), (self.stencil_width, self.stencil_width),
|
||||
(self.stencil_width, self.stencil_width)))
|
||||
self.slice = P.Slice()
|
||||
self.shape = P.Shape()
|
||||
self.dzf_kernel = dzf_kernel()
|
||||
|
||||
def construct(self, x):
|
||||
x1 = self.pad(x)
|
||||
x_shape = self.shape(x)
|
||||
x1 = self.dzf_kernel(x1)
|
||||
x1 = self.slice(x1, (self.stencil_width, self.stencil_width, self.stencil_width), x_shape)
|
||||
return x1
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
from src.read_var import read_nc
|
||||
from src.GOMO import GOMO_init, GOMO, read_init
|
||||
|
||||
parser = argparse.ArgumentParser(description='GOMO')
|
||||
parser.add_argument('--file_path', type=str, default=None, help='file path')
|
||||
parser.add_argument('--outputs_path', type=str, default=None, help='outputs path')
|
||||
parser.add_argument('--im', type=int, default=65, help='im size')
|
||||
parser.add_argument('--jm', type=int, default=49, help='jm size')
|
||||
parser.add_argument('--kb', type=int, default=21, help='kb size')
|
||||
parser.add_argument('--stencil_width', type=int, default=1, help='stencil width')
|
||||
parser.add_argument('--step', type=int, default=10, help='time step')
|
||||
args_gomo = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=False, enable_graph_kernel=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
variable = read_nc(args_gomo.file_path)
|
||||
im = args_gomo.im
|
||||
jm = args_gomo.jm
|
||||
kb = args_gomo.kb
|
||||
stencil_width = args_gomo.stencil_width
|
||||
|
||||
# variable init
|
||||
dx, dy, dz, uab, vab, elb, etb, sb, tb, ub, vb, dt, h, w, wubot, wvbot, vfluxb, utb, vtb, dhb, egb, vfluxf, z, zz, \
|
||||
dzz, cor, fsm = read_init(
|
||||
variable, im, jm, kb)
|
||||
|
||||
# define grid and init variable update
|
||||
net_init = GOMO_init(im, jm, kb, stencil_width)
|
||||
ua, va, el, et, etf, d, dt, l, q2b, q2lb, kh, km, kq, aam, w, q2, q2l, t, s, u, v, cbc, rmean, rho, x_d, y_d, z_d\
|
||||
= net_init(dx, dy, dz, uab, vab, elb, etb, sb, tb, ub, vb, h, w, vfluxf, zz, fsm)
|
||||
|
||||
# define GOMO model
|
||||
Model = GOMO(im=im, jm=jm, kb=kb, stencil_width=stencil_width, variable=variable, x_d=x_d, y_d=y_d, z_d=z_d,
|
||||
q2b=q2b, q2lb=q2lb, aam=aam, cbc=cbc, rmean=rmean)
|
||||
|
||||
# time step of GOMO Model
|
||||
for step in range(1, args_gomo.step+1):
|
||||
elf, etf, ua, uab, va, vab, el, elb, d, u, v, w, kq, km, kh, q2, q2l, tb, t, sb, s, rho, wubot, wvbot, ub, vb, \
|
||||
egb, etb, dt, dhb, utb, vtb, vfluxb, et, steps, vamax, q2b, q2lb = Model(
|
||||
etf, ua, uab, va, vab, el, elb, d, u, v, w, kq, km, kh, q2, q2l, tb, t, sb, s, rho,
|
||||
wubot, wvbot, ub, vb, egb, etb, dt, dhb, utb, vtb, vfluxb, et)
|
||||
vars_list = etf, ua, uab, va, vab, el, elb, d, u, v, w, kq, km, kh, q2, q2l, tb, t, sb, s, rho, wubot, wvbot, \
|
||||
ub, vb, egb, etb, dt, dhb, utb, vtb, vfluxb, et
|
||||
for var in vars_list:
|
||||
var.asnumpy()
|
||||
# save output
|
||||
if step % 5 == 0:
|
||||
np.save(args_gomo.outputs_path + "u_"+str(step)+".npy", u.asnumpy())
|
||||
np.save(args_gomo.outputs_path + "v_" + str(step) + ".npy", v.asnumpy())
|
||||
np.save(args_gomo.outputs_path + "t_" + str(step) + ".npy", t.asnumpy())
|
||||
np.save(args_gomo.outputs_path + "et_" + str(step) + ".npy", et.asnumpy())
|
Loading…
Reference in New Issue