From: @huchunmei
Reviewed-by: @oacjiewen,@c_34
Signed-off-by: @c_34
This commit is contained in:
mindspore-ci-bot 2021-05-12 16:02:15 +08:00 committed by Gitee
commit c8ef2924a9
11 changed files with 24 additions and 51 deletions

View File

@ -6,13 +6,12 @@ checkpoint_url: ""
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
checkpoint_path: './checkpoint/'
checkpoint_file: './checkpoint/checkpoint_lenet-10_1875.ckpt'
device_target: Ascend
enable_profiling: False
data_path_local: '/data/hcm/data/MNIST_Data/'
ckpt_path_local: '/data/hcm/data/ckpt_lenet/checkpoint_lenet-10_1875.ckpt'
ckpt_path: '/cache/data/'
ckpt_file: '/cache/data/checkpoint_lenet-10_1875.ckpt'
# ==============================================================================
# Training options
num_classes: 10

View File

@ -19,27 +19,20 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
"""
import os
# import sys
# sys.path.append(os.path.join(os.getcwd(), 'utils'))
from utils.config import config
from utils.moxing_adapter import moxing_wrapper
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.dataset import create_dataset
from src.lenet import LeNet5
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from src.dataset import create_dataset
from src.lenet import LeNet5
if os.path.exists(config.data_path_local):
config.data_path = config.data_path_local
ckpt_path = config.ckpt_path_local
else:
ckpt_path = os.path.join(config.data_path, 'checkpoint_lenet-10_1875.ckpt')
def modelarts_process():
pass
config.ckpt_path = config.ckpt_file
@moxing_wrapper(pre_process=modelarts_process)
def eval_lenet():
@ -53,7 +46,7 @@ def eval_lenet():
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Testing ==============")
param_dict = load_checkpoint(ckpt_path)
param_dict = load_checkpoint(config.ckpt_path)
load_param_into_net(network, param_dict)
ds_eval = create_dataset(os.path.join(config.data_path, "test"),
config.batch_size,

View File

@ -14,23 +14,15 @@
# ============================================================================
"""export checkpoint file into air, onnx, mindir models"""
import os
# import sys
# sys.path.append(os.path.join(os.getcwd(), 'utils'))
from utils.config import config
from utils.device_adapter import get_device_id
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id
from src.lenet import LeNet5
import numpy as np
import mindspore
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.lenet import LeNet5
if os.path.exists(config.data_path_local):
ckpt_file = config.ckpt_path_local
else:
ckpt_file = os.path.join(config.data_path, 'checkpoint_lenet-10_1875.ckpt')
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
context.set_context(device_id=get_device_id())
@ -40,7 +32,7 @@ if __name__ == "__main__":
# define fusion network
network = LeNet5(config.num_classes)
# load network checkpoint
param_dict = load_checkpoint(ckpt_file)
param_dict = load_checkpoint(config.ckpt_file)
load_param_into_net(network, param_dict)
# export network

View File

@ -17,7 +17,6 @@
# an simple tutorial as follows, more parameters can be setting
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
# DATA_PATH=$1
# CKPT_PATH=$2
# --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH
python -s ${self_path}/../eval.py > log_eval.txt 2>&1 &
DATA_PATH=$1
CKPT_PATH=$2
python -s ${self_path}/../eval.py --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH > log_eval.txt 2>&1 &

View File

@ -17,7 +17,6 @@
# an simple tutorial as follows, more parameters can be setting
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
# DATA_PATH=$1
# CKPT_PATH=$2
# --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH
python -s ${self_path}/../train.py > log.txt 2>&1 &
DATA_PATH=$1
CKPT_PATH=$2
python -s ${self_path}/../train.py --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH > log.txt 2>&1 &

View File

@ -115,7 +115,7 @@ def get_config():
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)

View File

@ -19,14 +19,11 @@ python train.py --data_path /YourDataPath
"""
import os
# import sys
# sys.path.append(os.path.join(os.getcwd(), 'utils'))
from utils.config import config
from utils.moxing_adapter import moxing_wrapper
from utils.device_adapter import get_rank_id
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.dataset import create_dataset
from src.lenet import LeNet5
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
@ -36,12 +33,6 @@ from mindspore.common import set_seed
set_seed(1)
if os.path.exists(config.data_path_local):
config.data_path = config.data_path_local
config.checkpoint_path = os.path.join(config.checkpoint_path, str(get_rank_id()))
else:
config.checkpoint_path = os.path.join(config.output_path, config.checkpoint_path, str(get_rank_id()))
def modelarts_pre_process():
pass
@ -59,7 +50,7 @@ def train_lenet():
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=config.checkpoint_path, config=config_ck)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=config.ckpt_path, config=config_ck)
if config.device_target != "Ascend":
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})