From: @huchunmei
Reviewed-by: @oacjiewen,@c_34
Signed-off-by: @c_34
This commit is contained in:
mindspore-ci-bot 2021-05-12 16:02:15 +08:00 committed by Gitee
commit c8ef2924a9
11 changed files with 24 additions and 51 deletions

View File

@ -6,13 +6,12 @@ checkpoint_url: ""
data_path: "/cache/data" data_path: "/cache/data"
output_path: "/cache/train" output_path: "/cache/train"
load_path: "/cache/checkpoint_path" load_path: "/cache/checkpoint_path"
checkpoint_path: './checkpoint/'
checkpoint_file: './checkpoint/checkpoint_lenet-10_1875.ckpt'
device_target: Ascend device_target: Ascend
enable_profiling: False enable_profiling: False
data_path_local: '/data/hcm/data/MNIST_Data/' ckpt_path: '/cache/data/'
ckpt_path_local: '/data/hcm/data/ckpt_lenet/checkpoint_lenet-10_1875.ckpt' ckpt_file: '/cache/data/checkpoint_lenet-10_1875.ckpt'
# ============================================================================== # ==============================================================================
# Training options # Training options
num_classes: 10 num_classes: 10

View File

@ -19,27 +19,20 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
""" """
import os import os
# import sys from src.model_utils.config import config
# sys.path.append(os.path.join(os.getcwd(), 'utils')) from src.model_utils.moxing_adapter import moxing_wrapper
from utils.config import config from src.dataset import create_dataset
from utils.moxing_adapter import moxing_wrapper from src.lenet import LeNet5
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model from mindspore.train import Model
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from src.dataset import create_dataset
from src.lenet import LeNet5
if os.path.exists(config.data_path_local):
config.data_path = config.data_path_local
ckpt_path = config.ckpt_path_local
else:
ckpt_path = os.path.join(config.data_path, 'checkpoint_lenet-10_1875.ckpt')
def modelarts_process(): def modelarts_process():
pass config.ckpt_path = config.ckpt_file
@moxing_wrapper(pre_process=modelarts_process) @moxing_wrapper(pre_process=modelarts_process)
def eval_lenet(): def eval_lenet():
@ -53,7 +46,7 @@ def eval_lenet():
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Testing ==============") print("============== Starting Testing ==============")
param_dict = load_checkpoint(ckpt_path) param_dict = load_checkpoint(config.ckpt_path)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
ds_eval = create_dataset(os.path.join(config.data_path, "test"), ds_eval = create_dataset(os.path.join(config.data_path, "test"),
config.batch_size, config.batch_size,

View File

@ -14,23 +14,15 @@
# ============================================================================ # ============================================================================
"""export checkpoint file into air, onnx, mindir models""" """export checkpoint file into air, onnx, mindir models"""
import os from src.model_utils.config import config
# import sys from src.model_utils.device_adapter import get_device_id
# sys.path.append(os.path.join(os.getcwd(), 'utils')) from src.lenet import LeNet5
from utils.config import config
from utils.device_adapter import get_device_id
import numpy as np import numpy as np
import mindspore import mindspore
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.lenet import LeNet5
if os.path.exists(config.data_path_local):
ckpt_file = config.ckpt_path_local
else:
ckpt_file = os.path.join(config.data_path, 'checkpoint_lenet-10_1875.ckpt')
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend": if config.device_target == "Ascend":
context.set_context(device_id=get_device_id()) context.set_context(device_id=get_device_id())
@ -40,7 +32,7 @@ if __name__ == "__main__":
# define fusion network # define fusion network
network = LeNet5(config.num_classes) network = LeNet5(config.num_classes)
# load network checkpoint # load network checkpoint
param_dict = load_checkpoint(ckpt_file) param_dict = load_checkpoint(config.ckpt_file)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
# export network # export network

View File

@ -17,7 +17,6 @@
# an simple tutorial as follows, more parameters can be setting # an simple tutorial as follows, more parameters can be setting
script_self=$(readlink -f "$0") script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}") self_path=$(dirname "${script_self}")
# DATA_PATH=$1 DATA_PATH=$1
# CKPT_PATH=$2 CKPT_PATH=$2
# --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH python -s ${self_path}/../eval.py --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH > log_eval.txt 2>&1 &
python -s ${self_path}/../eval.py > log_eval.txt 2>&1 &

View File

@ -17,7 +17,6 @@
# an simple tutorial as follows, more parameters can be setting # an simple tutorial as follows, more parameters can be setting
script_self=$(readlink -f "$0") script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}") self_path=$(dirname "${script_self}")
# DATA_PATH=$1 DATA_PATH=$1
# CKPT_PATH=$2 CKPT_PATH=$2
# --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH python -s ${self_path}/../train.py --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH > log.txt 2>&1 &
python -s ${self_path}/../train.py > log.txt 2>&1 &

View File

@ -115,7 +115,7 @@ def get_config():
""" """
parser = argparse.ArgumentParser(description="default name", add_help=False) parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"), parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
help="Config file path") help="Config file path")
path_args, _ = parser.parse_known_args() path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path) default, helper, choices = parse_yaml(path_args.config_path)

View File

@ -19,14 +19,11 @@ python train.py --data_path /YourDataPath
""" """
import os import os
# import sys from src.model_utils.config import config
# sys.path.append(os.path.join(os.getcwd(), 'utils')) from src.model_utils.moxing_adapter import moxing_wrapper
from utils.config import config
from utils.moxing_adapter import moxing_wrapper
from utils.device_adapter import get_rank_id
from src.dataset import create_dataset from src.dataset import create_dataset
from src.lenet import LeNet5 from src.lenet import LeNet5
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
@ -36,12 +33,6 @@ from mindspore.common import set_seed
set_seed(1) set_seed(1)
if os.path.exists(config.data_path_local):
config.data_path = config.data_path_local
config.checkpoint_path = os.path.join(config.checkpoint_path, str(get_rank_id()))
else:
config.checkpoint_path = os.path.join(config.output_path, config.checkpoint_path, str(get_rank_id()))
def modelarts_pre_process(): def modelarts_pre_process():
pass pass
@ -59,7 +50,7 @@ def train_lenet():
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
keep_checkpoint_max=config.keep_checkpoint_max) keep_checkpoint_max=config.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=config.checkpoint_path, config=config_ck) ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=config.ckpt_path, config=config_ck)
if config.device_target != "Ascend": if config.device_target != "Ascend":
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})