forked from mindspore-Ecosystem/mindspore
!16205 lenet test
From: @huchunmei Reviewed-by: @oacjiewen,@c_34 Signed-off-by: @c_34
This commit is contained in:
commit
c8ef2924a9
|
@ -6,13 +6,12 @@ checkpoint_url: ""
|
||||||
data_path: "/cache/data"
|
data_path: "/cache/data"
|
||||||
output_path: "/cache/train"
|
output_path: "/cache/train"
|
||||||
load_path: "/cache/checkpoint_path"
|
load_path: "/cache/checkpoint_path"
|
||||||
checkpoint_path: './checkpoint/'
|
|
||||||
checkpoint_file: './checkpoint/checkpoint_lenet-10_1875.ckpt'
|
|
||||||
device_target: Ascend
|
device_target: Ascend
|
||||||
enable_profiling: False
|
enable_profiling: False
|
||||||
|
|
||||||
data_path_local: '/data/hcm/data/MNIST_Data/'
|
ckpt_path: '/cache/data/'
|
||||||
ckpt_path_local: '/data/hcm/data/ckpt_lenet/checkpoint_lenet-10_1875.ckpt'
|
ckpt_file: '/cache/data/checkpoint_lenet-10_1875.ckpt'
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Training options
|
# Training options
|
||||||
num_classes: 10
|
num_classes: 10
|
||||||
|
|
|
@ -19,27 +19,20 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
# import sys
|
from src.model_utils.config import config
|
||||||
# sys.path.append(os.path.join(os.getcwd(), 'utils'))
|
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||||
from utils.config import config
|
from src.dataset import create_dataset
|
||||||
from utils.moxing_adapter import moxing_wrapper
|
from src.lenet import LeNet5
|
||||||
|
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.train import Model
|
from mindspore.train import Model
|
||||||
from mindspore.nn.metrics import Accuracy
|
from mindspore.nn.metrics import Accuracy
|
||||||
from src.dataset import create_dataset
|
|
||||||
from src.lenet import LeNet5
|
|
||||||
|
|
||||||
if os.path.exists(config.data_path_local):
|
|
||||||
config.data_path = config.data_path_local
|
|
||||||
ckpt_path = config.ckpt_path_local
|
|
||||||
else:
|
|
||||||
ckpt_path = os.path.join(config.data_path, 'checkpoint_lenet-10_1875.ckpt')
|
|
||||||
|
|
||||||
def modelarts_process():
|
def modelarts_process():
|
||||||
pass
|
config.ckpt_path = config.ckpt_file
|
||||||
|
|
||||||
@moxing_wrapper(pre_process=modelarts_process)
|
@moxing_wrapper(pre_process=modelarts_process)
|
||||||
def eval_lenet():
|
def eval_lenet():
|
||||||
|
@ -53,7 +46,7 @@ def eval_lenet():
|
||||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||||
|
|
||||||
print("============== Starting Testing ==============")
|
print("============== Starting Testing ==============")
|
||||||
param_dict = load_checkpoint(ckpt_path)
|
param_dict = load_checkpoint(config.ckpt_path)
|
||||||
load_param_into_net(network, param_dict)
|
load_param_into_net(network, param_dict)
|
||||||
ds_eval = create_dataset(os.path.join(config.data_path, "test"),
|
ds_eval = create_dataset(os.path.join(config.data_path, "test"),
|
||||||
config.batch_size,
|
config.batch_size,
|
||||||
|
|
|
@ -14,23 +14,15 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""export checkpoint file into air, onnx, mindir models"""
|
"""export checkpoint file into air, onnx, mindir models"""
|
||||||
|
|
||||||
import os
|
from src.model_utils.config import config
|
||||||
# import sys
|
from src.model_utils.device_adapter import get_device_id
|
||||||
# sys.path.append(os.path.join(os.getcwd(), 'utils'))
|
from src.lenet import LeNet5
|
||||||
from utils.config import config
|
|
||||||
from utils.device_adapter import get_device_id
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import mindspore
|
import mindspore
|
||||||
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
||||||
from src.lenet import LeNet5
|
|
||||||
|
|
||||||
|
|
||||||
if os.path.exists(config.data_path_local):
|
|
||||||
ckpt_file = config.ckpt_path_local
|
|
||||||
else:
|
|
||||||
ckpt_file = os.path.join(config.data_path, 'checkpoint_lenet-10_1875.ckpt')
|
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
|
||||||
if config.device_target == "Ascend":
|
if config.device_target == "Ascend":
|
||||||
context.set_context(device_id=get_device_id())
|
context.set_context(device_id=get_device_id())
|
||||||
|
@ -40,7 +32,7 @@ if __name__ == "__main__":
|
||||||
# define fusion network
|
# define fusion network
|
||||||
network = LeNet5(config.num_classes)
|
network = LeNet5(config.num_classes)
|
||||||
# load network checkpoint
|
# load network checkpoint
|
||||||
param_dict = load_checkpoint(ckpt_file)
|
param_dict = load_checkpoint(config.ckpt_file)
|
||||||
load_param_into_net(network, param_dict)
|
load_param_into_net(network, param_dict)
|
||||||
|
|
||||||
# export network
|
# export network
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
# an simple tutorial as follows, more parameters can be setting
|
# an simple tutorial as follows, more parameters can be setting
|
||||||
script_self=$(readlink -f "$0")
|
script_self=$(readlink -f "$0")
|
||||||
self_path=$(dirname "${script_self}")
|
self_path=$(dirname "${script_self}")
|
||||||
# DATA_PATH=$1
|
DATA_PATH=$1
|
||||||
# CKPT_PATH=$2
|
CKPT_PATH=$2
|
||||||
# --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH
|
python -s ${self_path}/../eval.py --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH > log_eval.txt 2>&1 &
|
||||||
python -s ${self_path}/../eval.py > log_eval.txt 2>&1 &
|
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
# an simple tutorial as follows, more parameters can be setting
|
# an simple tutorial as follows, more parameters can be setting
|
||||||
script_self=$(readlink -f "$0")
|
script_self=$(readlink -f "$0")
|
||||||
self_path=$(dirname "${script_self}")
|
self_path=$(dirname "${script_self}")
|
||||||
# DATA_PATH=$1
|
DATA_PATH=$1
|
||||||
# CKPT_PATH=$2
|
CKPT_PATH=$2
|
||||||
# --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH
|
python -s ${self_path}/../train.py --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH > log.txt 2>&1 &
|
||||||
python -s ${self_path}/../train.py > log.txt 2>&1 &
|
|
||||||
|
|
|
@ -115,7 +115,7 @@ def get_config():
|
||||||
"""
|
"""
|
||||||
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
parser = argparse.ArgumentParser(description="default name", add_help=False)
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"),
|
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
|
||||||
help="Config file path")
|
help="Config file path")
|
||||||
path_args, _ = parser.parse_known_args()
|
path_args, _ = parser.parse_known_args()
|
||||||
default, helper, choices = parse_yaml(path_args.config_path)
|
default, helper, choices = parse_yaml(path_args.config_path)
|
|
@ -19,14 +19,11 @@ python train.py --data_path /YourDataPath
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
# import sys
|
from src.model_utils.config import config
|
||||||
# sys.path.append(os.path.join(os.getcwd(), 'utils'))
|
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||||
from utils.config import config
|
|
||||||
from utils.moxing_adapter import moxing_wrapper
|
|
||||||
from utils.device_adapter import get_rank_id
|
|
||||||
|
|
||||||
from src.dataset import create_dataset
|
from src.dataset import create_dataset
|
||||||
from src.lenet import LeNet5
|
from src.lenet import LeNet5
|
||||||
|
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||||
|
@ -36,12 +33,6 @@ from mindspore.common import set_seed
|
||||||
|
|
||||||
set_seed(1)
|
set_seed(1)
|
||||||
|
|
||||||
if os.path.exists(config.data_path_local):
|
|
||||||
config.data_path = config.data_path_local
|
|
||||||
config.checkpoint_path = os.path.join(config.checkpoint_path, str(get_rank_id()))
|
|
||||||
else:
|
|
||||||
config.checkpoint_path = os.path.join(config.output_path, config.checkpoint_path, str(get_rank_id()))
|
|
||||||
|
|
||||||
def modelarts_pre_process():
|
def modelarts_pre_process():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -59,7 +50,7 @@ def train_lenet():
|
||||||
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
||||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
|
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
|
||||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=config.checkpoint_path, config=config_ck)
|
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=config.ckpt_path, config=config_ck)
|
||||||
|
|
||||||
if config.device_target != "Ascend":
|
if config.device_target != "Ascend":
|
||||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||||
|
|
Loading…
Reference in New Issue