!23345 modify pinns network scripts
Merge pull request !23345 from anzhengqi/modify-pinns
This commit is contained in:
commit
23fb972eae
|
@ -104,7 +104,7 @@ After installing MindSpore via the official website, you can start training and
|
|||
bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [PRED_TRAINED PATH] [TRAIN_ROOT_DIR]
|
||||
|
||||
#enter the path ,run Makefile
|
||||
cd ./src/ETSNET/pse/;make
|
||||
cd ./src/PSENET/pse/;make
|
||||
|
||||
#run test.py
|
||||
python test.py --ckpt pretrained_model.ckpt --TEST_ROOT_DIR [test root path]
|
||||
|
@ -453,7 +453,7 @@ time_cb = TimeMonitor(data_size=step_size)
|
|||
loss_cb = LossCallBack(per_print_times=20)
|
||||
# set and apply parameters of check point
|
||||
ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=2)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="ETSNet", config=ckpoint_cf, directory=config.TRAIN_MODEL_SAVE_PATH)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="PSENet", config=ckpoint_cf, directory=config.TRAIN_MODEL_SAVE_PATH)
|
||||
|
||||
model = Model(net)
|
||||
model.train(config.TRAIN_REPEAT_NUM, ds, dataset_sink_mode=False, callbacks=[time_cb, loss_cb, ckpoint_cb])
|
||||
|
|
|
@ -105,7 +105,7 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64
|
|||
bash scripts/run_distribute_train.sh [RANK_FILE] [PRETRAINED_PATH] [TRAIN_ROOT_DIR]
|
||||
|
||||
# 进入路径,运行Makefile
|
||||
cd ./src/ETSNET/pse/;make clean&&make
|
||||
cd ./src/PSENET/pse/;make clean&&make
|
||||
|
||||
# 运行test.py
|
||||
python test.py --ckpt [CKPK_PATH] --TEST_ROOT_DIR [TEST_DATA_DIR]
|
||||
|
@ -391,7 +391,7 @@ loss_cb = LossCallBack(per_print_times=20)
|
|||
|
||||
# 设置并应用检查点参数
|
||||
ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=2)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="ETSNet", config=ckpoint_cf, directory=config.TRAIN_MODEL_SAVE_PATH)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="PSENet", config=ckpoint_cf, directory=config.TRAIN_MODEL_SAVE_PATH)
|
||||
|
||||
model = Model(net)
|
||||
model.train(config.TRAIN_REPEAT_NUM, ds, dataset_sink_mode=False, callbacks=[time_cb, loss_cb, ckpoint_cb])
|
||||
|
|
|
@ -13,12 +13,12 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config."""
|
||||
from src.ETSNET.etsnet import ETSNet
|
||||
from src.PSENET.psenet import PSENet
|
||||
from src.config import config
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
if name == "psenet":
|
||||
infer_mode = kwargs.get("infer_mode", False)
|
||||
config.INFERENCE = infer_mode
|
||||
return ETSNet(config)
|
||||
return PSENet(config)
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
||||
|
|
|
@ -83,7 +83,7 @@ def modelarts_pre_process():
|
|||
cmake_command = 'cmake -D CMAKE_BUILD_TYPE=Release -D CMAKE_INSTALL=/usr/local ..&&make -j16&&sudo make install'
|
||||
os.system('cd {}/opencv-3.4.9&&mkdir build&&cd ./build&&{}'.format(local_path, cmake_command))
|
||||
|
||||
os.system('cd {}/src/ETSNET/pse&&make clean&&make'.format(local_path))
|
||||
os.system('cd {}/src/PSENET/pse&&make clean&&make'.format(local_path))
|
||||
os.system('cd {}&&sed -i ’s/\r//‘ scripts/run_eval_ascend.sh'.format(local_path))
|
||||
|
||||
|
||||
|
@ -94,7 +94,7 @@ def modelarts_post_process():
|
|||
|
||||
@moxing_wrapper(pre_process=modelarts_pre_process, post_process=modelarts_post_process)
|
||||
def test():
|
||||
from src.ETSNET.pse import pse
|
||||
from src.PSENET.pse import pse
|
||||
|
||||
local_path = ""
|
||||
if config.enable_modelarts:
|
||||
|
|
|
@ -13,16 +13,59 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Train PINNs for Navier-Stokes equation scenario"""
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import numpy as np
|
||||
from mindspore import Model, context, nn
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.callback import (CheckpointConfig, LossMonitor,
|
||||
ModelCheckpoint, TimeMonitor)
|
||||
ModelCheckpoint, TimeMonitor, Callback)
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.NavierStokes.dataset import generate_training_set_navier_stokes
|
||||
from src.NavierStokes.loss import PINNs_loss_navier
|
||||
from src.NavierStokes.net import PINNs_navier
|
||||
|
||||
|
||||
class EvalCallback(Callback):
|
||||
"""eval callback."""
|
||||
def __init__(self, data_path, ckpt_dir, per_eval_epoch, num_neuron=20):
|
||||
super(EvalCallback, self).__init__()
|
||||
if not isinstance(per_eval_epoch, int) or per_eval_epoch <= 0:
|
||||
raise ValueError("per_eval_epoch must be int and > 0")
|
||||
layers = [3, num_neuron, num_neuron, num_neuron, num_neuron, num_neuron, num_neuron, num_neuron,
|
||||
num_neuron, 2]
|
||||
_, lb, ub = generate_training_set_navier_stokes(10, 10, data_path, 0)
|
||||
self.network = PINNs_navier(layers, lb, ub)
|
||||
self.ckpt_dir = ckpt_dir
|
||||
self.per_eval_epoch = per_eval_epoch
|
||||
self.best_result = None
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
"""epoch end function."""
|
||||
cb_params = run_context.original_args()
|
||||
cur_epoch = cb_params.cur_epoch_num
|
||||
batch_num = cb_params.batch_num
|
||||
if cur_epoch % self.per_eval_epoch == 0:
|
||||
ckpt_format = os.path.join(self.ckpt_dir,
|
||||
"checkpoint_PINNs_NavierStokes*-{}_{}.ckpt".format(cur_epoch, batch_num))
|
||||
ckpt_list = glob.glob(ckpt_format)
|
||||
if not ckpt_list:
|
||||
raise ValueError("can not find {}".format(ckpt_format))
|
||||
ckpt_name = sorted(ckpt_list)[-1]
|
||||
print("the latest ckpt_name is", ckpt_name)
|
||||
param_dict = load_checkpoint(ckpt_name)
|
||||
load_param_into_net(self.network, param_dict)
|
||||
lambda1_pred = self.network.lambda1.asnumpy()
|
||||
lambda2_pred = self.network.lambda2.asnumpy()
|
||||
error1 = np.abs(lambda1_pred - 1.0) * 100
|
||||
error2 = np.abs(lambda2_pred - 0.01) / 0.01 * 100
|
||||
print(f'Error of lambda 1 is {error1[0]:.6f}%')
|
||||
print(f'Error of lambda 2 is {error2[0]:.6f}%')
|
||||
if self.best_result is None or error1 + error2 < self.best_result:
|
||||
self.best_result = error1 + error2
|
||||
shutil.copyfile(ckpt_name, os.path.join(self.ckpt_dir, "best_result.ckpt"))
|
||||
|
||||
def train_navier(epoch, lr, batch_size, n_train, path, noise, num_neuron, ck_path, seed=None):
|
||||
"""
|
||||
Train PINNs for Navier-Stokes equation
|
||||
|
@ -57,9 +100,10 @@ def train_navier(epoch, lr, batch_size, n_train, path, noise, num_neuron, ck_pat
|
|||
# save model
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=1000, keep_checkpoint_max=20)
|
||||
ckpoint = ModelCheckpoint(prefix="checkpoint_PINNs_NavierStokes", directory=ck_path, config=config_ck)
|
||||
eval_cb = EvalCallback(data_path=path, ckpt_dir=ck_path, per_eval_epoch=100)
|
||||
|
||||
model = Model(network=n, loss_fn=loss, optimizer=opt)
|
||||
|
||||
model.train(epoch=epoch, train_dataset=training_set,
|
||||
callbacks=[LossMonitor(loss_print_num), ckpoint, TimeMonitor(1)], dataset_sink_mode=True)
|
||||
callbacks=[LossMonitor(loss_print_num), ckpoint, TimeMonitor(1), eval_cb], dataset_sink_mode=True)
|
||||
print('Training complete')
|
||||
|
|
Loading…
Reference in New Issue