forked from mindspore-Ecosystem/mindspore

modify pinns network scripts

commit b0c3af6f10 (parent 901124c9bd)
@@ -104,7 +104,7 @@ After installing MindSpore via the official website, you can start training and
 bash scripts/run_distribute_train.sh [RANK_TABLE_FILE] [PRED_TRAINED PATH] [TRAIN_ROOT_DIR]

 #enter the path ,run Makefile
-cd ./src/ETSNET/pse/;make
+cd ./src/PSENET/pse/;make

 #run test.py
 python test.py --ckpt pretrained_model.ckpt --TEST_ROOT_DIR [test root path]
@@ -453,7 +453,7 @@ time_cb = TimeMonitor(data_size=step_size)
 loss_cb = LossCallBack(per_print_times=20)
 # set and apply parameters of check point
 ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=2)
-ckpoint_cb = ModelCheckpoint(prefix="ETSNet", config=ckpoint_cf, directory=config.TRAIN_MODEL_SAVE_PATH)
+ckpoint_cb = ModelCheckpoint(prefix="PSENet", config=ckpoint_cf, directory=config.TRAIN_MODEL_SAVE_PATH)

 model = Model(net)
 model.train(config.TRAIN_REPEAT_NUM, ds, dataset_sink_mode=False, callbacks=[time_cb, loss_cb, ckpoint_cb])
@@ -105,7 +105,7 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64
 bash scripts/run_distribute_train.sh [RANK_FILE] [PRETRAINED_PATH] [TRAIN_ROOT_DIR]

 # 进入路径,运行Makefile
-cd ./src/ETSNET/pse/;make clean&&make
+cd ./src/PSENET/pse/;make clean&&make

 # 运行test.py
 python test.py --ckpt [CKPK_PATH] --TEST_ROOT_DIR [TEST_DATA_DIR]
@@ -391,7 +391,7 @@ loss_cb = LossCallBack(per_print_times=20)

 # 设置并应用检查点参数
 ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=2)
-ckpoint_cb = ModelCheckpoint(prefix="ETSNet", config=ckpoint_cf, directory=config.TRAIN_MODEL_SAVE_PATH)
+ckpoint_cb = ModelCheckpoint(prefix="PSENet", config=ckpoint_cf, directory=config.TRAIN_MODEL_SAVE_PATH)

 model = Model(net)
 model.train(config.TRAIN_REPEAT_NUM, ds, dataset_sink_mode=False, callbacks=[time_cb, loss_cb, ckpoint_cb])
@@ -13,12 +13,12 @@
 # limitations under the License.
 # ============================================================================
 """hub config."""
-from src.ETSNET.etsnet import ETSNet
+from src.PSENET.psenet import PSENet
 from src.config import config

 def create_network(name, *args, **kwargs):
     if name == "psenet":
         infer_mode = kwargs.get("infer_mode", False)
         config.INFERENCE = infer_mode
-        return ETSNet(config)
+        return PSENet(config)
     raise NotImplementedError(f"{name} is not implemented in the repo")
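For quick orientation, a minimal usage sketch of the renamed hub factory follows. The module name mindspore_hub_conf is an assumption (the diff shows only the file body), and the example assumes it is run from the repository root so that src/ is importable.

# Usage sketch for the hub config entry point shown above (module name assumed).
from mindspore_hub_conf import create_network

# "psenet" is the only name the factory accepts; infer_mode toggles config.INFERENCE.
net = create_network("psenet", infer_mode=True)
print(type(net).__name__)  # expected to print "PSENet" after this change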
@@ -83,7 +83,7 @@ def modelarts_pre_process():
     cmake_command = 'cmake -D CMAKE_BUILD_TYPE=Release -D CMAKE_INSTALL=/usr/local ..&&make -j16&&sudo make install'
     os.system('cd {}/opencv-3.4.9&&mkdir build&&cd ./build&&{}'.format(local_path, cmake_command))

-    os.system('cd {}/src/ETSNET/pse&&make clean&&make'.format(local_path))
+    os.system('cd {}/src/PSENET/pse&&make clean&&make'.format(local_path))
     os.system('cd {}&&sed -i ’s/\r//‘ scripts/run_eval_ascend.sh'.format(local_path))

@@ -94,7 +94,7 @@ def modelarts_post_process():

 @moxing_wrapper(pre_process=modelarts_pre_process, post_process=modelarts_post_process)
 def test():
-    from src.ETSNET.pse import pse
+    from src.PSENET.pse import pse

     local_path = ""
     if config.enable_modelarts:
@@ -13,16 +13,59 @@
 # limitations under the License.
 # ============================================================================
 """Train PINNs for Navier-Stokes equation scenario"""
+import glob
+import os
+import shutil
 import numpy as np
 from mindspore import Model, context, nn
 from mindspore.common import set_seed
 from mindspore.train.callback import (CheckpointConfig, LossMonitor,
-                                      ModelCheckpoint, TimeMonitor)
+                                      ModelCheckpoint, TimeMonitor, Callback)
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.NavierStokes.dataset import generate_training_set_navier_stokes
 from src.NavierStokes.loss import PINNs_loss_navier
 from src.NavierStokes.net import PINNs_navier


+class EvalCallback(Callback):
+    """eval callback."""
+    def __init__(self, data_path, ckpt_dir, per_eval_epoch, num_neuron=20):
+        super(EvalCallback, self).__init__()
+        if not isinstance(per_eval_epoch, int) or per_eval_epoch <= 0:
+            raise ValueError("per_eval_epoch must be int and > 0")
+        layers = [3, num_neuron, num_neuron, num_neuron, num_neuron, num_neuron, num_neuron, num_neuron,
+                  num_neuron, 2]
+        _, lb, ub = generate_training_set_navier_stokes(10, 10, data_path, 0)
+        self.network = PINNs_navier(layers, lb, ub)
+        self.ckpt_dir = ckpt_dir
+        self.per_eval_epoch = per_eval_epoch
+        self.best_result = None
+
+    def epoch_end(self, run_context):
+        """epoch end function."""
+        cb_params = run_context.original_args()
+        cur_epoch = cb_params.cur_epoch_num
+        batch_num = cb_params.batch_num
+        if cur_epoch % self.per_eval_epoch == 0:
+            ckpt_format = os.path.join(self.ckpt_dir,
+                                       "checkpoint_PINNs_NavierStokes*-{}_{}.ckpt".format(cur_epoch, batch_num))
+            ckpt_list = glob.glob(ckpt_format)
+            if not ckpt_list:
+                raise ValueError("can not find {}".format(ckpt_format))
+            ckpt_name = sorted(ckpt_list)[-1]
+            print("the latest ckpt_name is", ckpt_name)
+            param_dict = load_checkpoint(ckpt_name)
+            load_param_into_net(self.network, param_dict)
+            lambda1_pred = self.network.lambda1.asnumpy()
+            lambda2_pred = self.network.lambda2.asnumpy()
+            error1 = np.abs(lambda1_pred - 1.0) * 100
+            error2 = np.abs(lambda2_pred - 0.01) / 0.01 * 100
+            print(f'Error of lambda 1 is {error1[0]:.6f}%')
+            print(f'Error of lambda 2 is {error2[0]:.6f}%')
+            if self.best_result is None or error1 + error2 < self.best_result:
+                self.best_result = error1 + error2
+                shutil.copyfile(ckpt_name, os.path.join(self.ckpt_dir, "best_result.ckpt"))
+
 def train_navier(epoch, lr, batch_size, n_train, path, noise, num_neuron, ck_path, seed=None):
     """
     Train PINNs for Navier-Stokes equation
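As an aside, the new EvalCallback copies its best checkpoint to best_result.ckpt, so the reported lambda errors can be re-checked offline after training. Below is a minimal sketch of such a check, assuming checkpoints were written under ./ckpt, the dataset file sits at ./data/cylinder_nektar_wake.mat, and num_neuron kept its default of 20; those paths and values are assumptions rather than facts from this diff.

# Offline re-check of the callback's best checkpoint (sketch, not part of the commit).
# Assumed locations: ./ckpt for checkpoints, ./data/cylinder_nektar_wake.mat for the dataset.
import numpy as np
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.NavierStokes.dataset import generate_training_set_navier_stokes
from src.NavierStokes.net import PINNs_navier

layers = [3] + [20] * 8 + [2]  # same layer layout the callback builds with num_neuron=20
_, lb, ub = generate_training_set_navier_stokes(10, 10, "./data/cylinder_nektar_wake.mat", 0)
net = PINNs_navier(layers, lb, ub)
load_param_into_net(net, load_checkpoint("./ckpt/best_result.ckpt"))

# The true Navier-Stokes coefficients in this scenario are 1.0 and 0.01, so the
# same relative errors the callback prints can be recomputed here.
err1 = np.abs(net.lambda1.asnumpy() - 1.0) * 100
err2 = np.abs(net.lambda2.asnumpy() - 0.01) / 0.01 * 100
print(f"lambda1 error: {err1[0]:.6f}%, lambda2 error: {err2[0]:.6f}%")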
@@ -57,9 +100,10 @@ def train_navier(epoch, lr, batch_size, n_train, path, noise, num_neuron, ck_pat
     # save model
     config_ck = CheckpointConfig(save_checkpoint_steps=1000, keep_checkpoint_max=20)
     ckpoint = ModelCheckpoint(prefix="checkpoint_PINNs_NavierStokes", directory=ck_path, config=config_ck)
+    eval_cb = EvalCallback(data_path=path, ckpt_dir=ck_path, per_eval_epoch=100)

     model = Model(network=n, loss_fn=loss, optimizer=opt)

     model.train(epoch=epoch, train_dataset=training_set,
-                callbacks=[LossMonitor(loss_print_num), ckpoint, TimeMonitor(1)], dataset_sink_mode=True)
+                callbacks=[LossMonitor(loss_print_num), ckpoint, TimeMonitor(1), eval_cb], dataset_sink_mode=True)
     print('Training complete')
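To close the loop, here is a sketch of how the patched training entry point might be driven. The module name pinns_train_ns and every hyperparameter value below are placeholders for illustration; only the train_navier signature and the fact that EvalCallback runs every per_eval_epoch=100 epochs come from the diff itself.

# Hypothetical driver for the patched trainer; all values below are assumed placeholders.
from pinns_train_ns import train_navier  # module name is assumed, the diff does not show the file name

train_navier(epoch=500,          # assumed epoch budget (>= 100, so EvalCallback fires at least once)
             lr=0.01,            # assumed learning rate
             batch_size=500,     # assumed batch size
             n_train=5000,       # assumed number of training points
             path="./data/cylinder_nektar_wake.mat",  # assumed dataset location
             noise=0.0,          # train on clean data
             num_neuron=20,      # matches EvalCallback's default network width
             ck_path="./ckpt",   # checkpoints and best_result.ckpt land here
             seed=0)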