forked from mindspore-Ecosystem/mindspore
!16205 lenet test
From: @huchunmei Reviewed-by: @oacjiewen,@c_34 Signed-off-by: @c_34
Commit c8ef2924a9
@@ -6,13 +6,12 @@ checkpoint_url: ""
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
checkpoint_path: './checkpoint/'
checkpoint_file: './checkpoint/checkpoint_lenet-10_1875.ckpt'
device_target: Ascend
enable_profiling: False

data_path_local: '/data/hcm/data/MNIST_Data/'
ckpt_path_local: '/data/hcm/data/ckpt_lenet/checkpoint_lenet-10_1875.ckpt'
ckpt_path: '/cache/data/'
ckpt_file: '/cache/data/checkpoint_lenet-10_1875.ckpt'

# ==============================================================================
# Training options
num_classes: 10
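Note: the Python entry points below read these keys through a shared config object, so a YAML entry such as ckpt_file is accessed as config.ckpt_file at runtime. A minimal sketch of that mapping (illustrative only; the repository's real parser is the get_config/parse_yaml code changed further down, and the yaml/argparse calls here are assumptions):

    import argparse
    import yaml  # assumes PyYAML is available

    def load_config(path="default_config.yaml"):
        # Read the flat key/value pairs shown above and expose them as attributes.
        with open(path, "r") as f:
            options = yaml.safe_load(f)
        return argparse.Namespace(**options)  # e.g. cfg.data_path, cfg.ckpt_file

    cfg = load_config()
    print(cfg.device_target, cfg.ckpt_file)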
@@ -19,27 +19,20 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
"""

import os
# import sys
# sys.path.append(os.path.join(os.getcwd(), 'utils'))
from utils.config import config
from utils.moxing_adapter import moxing_wrapper
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.dataset import create_dataset
from src.lenet import LeNet5

import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from src.dataset import create_dataset
from src.lenet import LeNet5

if os.path.exists(config.data_path_local):
    config.data_path = config.data_path_local
    ckpt_path = config.ckpt_path_local
else:
    ckpt_path = os.path.join(config.data_path, 'checkpoint_lenet-10_1875.ckpt')

def modelarts_process():
    pass
    config.ckpt_path = config.ckpt_file

@moxing_wrapper(pre_process=modelarts_process)
def eval_lenet():
@@ -53,7 +46,7 @@ def eval_lenet():
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    print("============== Starting Testing ==============")
    param_dict = load_checkpoint(ckpt_path)
    param_dict = load_checkpoint(config.ckpt_path)
    load_param_into_net(network, param_dict)
    ds_eval = create_dataset(os.path.join(config.data_path, "test"),
                             config.batch_size,
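Note: eval.py no longer derives the checkpoint path at module level from data_path_local; instead modelarts_process sets config.ckpt_path = config.ckpt_file, and eval_lenet loads config.ckpt_path. A sketch of the pre_process hook pattern behind @moxing_wrapper (illustration only; the real wrapper in src/model_utils/moxing_adapter.py also guards this on ModelArts and handles data transfer):

    import functools

    def moxing_wrapper(pre_process=None):
        def decorator(run_func):
            @functools.wraps(run_func)
            def wrapped(*args, **kwargs):
                if pre_process:
                    pre_process()  # e.g. modelarts_process: config.ckpt_path = config.ckpt_file
                return run_func(*args, **kwargs)
            return wrapped
        return decorator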
@@ -14,23 +14,15 @@
# ============================================================================
"""export checkpoint file into air, onnx, mindir models"""

import os
# import sys
# sys.path.append(os.path.join(os.getcwd(), 'utils'))
from utils.config import config
from utils.device_adapter import get_device_id
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id
from src.lenet import LeNet5

import numpy as np
import mindspore
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.lenet import LeNet5


if os.path.exists(config.data_path_local):
    ckpt_file = config.ckpt_path_local
else:
    ckpt_file = os.path.join(config.data_path, 'checkpoint_lenet-10_1875.ckpt')

context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
    context.set_context(device_id=get_device_id())
@@ -40,7 +32,7 @@ if __name__ == "__main__":
    # define fusion network
    network = LeNet5(config.num_classes)
    # load network checkpoint
    param_dict = load_checkpoint(ckpt_file)
    param_dict = load_checkpoint(config.ckpt_file)
    load_param_into_net(network, param_dict)

    # export network
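Note: the hunk above ends at the "# export network" comment; a sketch of the export call that typically follows it (the 1x32x32 MNIST input shape and the file_name/file_format config keys are assumptions, not shown in this diff):

    inputs = Tensor(np.ones([config.batch_size, 1, 32, 32]), mindspore.float32)
    export(network, inputs, file_name=config.file_name, file_format=config.file_format)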
@@ -17,7 +17,6 @@
# an simple tutorial as follows, more parameters can be setting
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
# DATA_PATH=$1
# CKPT_PATH=$2
# --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH
python -s ${self_path}/../eval.py > log_eval.txt 2>&1 &
DATA_PATH=$1
CKPT_PATH=$2
python -s ${self_path}/../eval.py --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH > log_eval.txt 2>&1 &
@@ -17,7 +17,6 @@
# an simple tutorial as follows, more parameters can be setting
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
# DATA_PATH=$1
# CKPT_PATH=$2
# --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH
python -s ${self_path}/../train.py > log.txt 2>&1 &
DATA_PATH=$1
CKPT_PATH=$2
python -s ${self_path}/../train.py --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH > log.txt 2>&1 &
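Note: both launch scripts now take the dataset and checkpoint paths as positional arguments instead of relying on the config defaults, e.g. bash run_eval.sh /path/to/MNIST_Data /path/to/checkpoint_lenet-10_1875.ckpt (the script file names are not shown in this diff, so run_eval.sh here is only a placeholder).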
@@ -115,7 +115,7 @@ def get_config():
    """
    parser = argparse.ArgumentParser(description="default name", add_help=False)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
    parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
                        help="Config file path")
    path_args, _ = parser.parse_known_args()
    default, helper, choices = parse_yaml(path_args.config_path)
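Note: the extra "../" reflects the config module moving one directory deeper: a config.py under utils/ reaches the top-level default_config.yaml via "../default_config.yaml", while src/model_utils/config.py needs "../../default_config.yaml" to resolve to the same file (the directory layout is inferred from the import changes elsewhere in this commit).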
@@ -19,14 +19,11 @@ python train.py --data_path /YourDataPath
"""

import os
# import sys
# sys.path.append(os.path.join(os.getcwd(), 'utils'))
from utils.config import config
from utils.moxing_adapter import moxing_wrapper
from utils.device_adapter import get_rank_id

from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.dataset import create_dataset
from src.lenet import LeNet5

import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
@@ -36,12 +33,6 @@ from mindspore.common import set_seed

set_seed(1)

if os.path.exists(config.data_path_local):
    config.data_path = config.data_path_local
    config.checkpoint_path = os.path.join(config.checkpoint_path, str(get_rank_id()))
else:
    config.checkpoint_path = os.path.join(config.output_path, config.checkpoint_path, str(get_rank_id()))

def modelarts_pre_process():
    pass

@@ -59,7 +50,7 @@ def train_lenet():
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
                                 keep_checkpoint_max=config.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=config.checkpoint_path, config=config_ck)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=config.ckpt_path, config=config_ck)

    if config.device_target != "Ascend":
        model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
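Note: ModelCheckpoint now writes to config.ckpt_path instead of config.checkpoint_path, consistent with dropping the local-path setup from train.py's module level. For context, a sketch of how these callbacks are usually attached (the actual model.train call sits outside this hunk; epoch_size and the sink-mode choice are assumptions):

    model.train(config.epoch_size, ds_train,
                callbacks=[time_cb, ckpoint_cb, LossMonitor()],
                dataset_sink_mode=config.device_target != "CPU")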