diff --git a/model_zoo/official/cv/lenet/default_config.yaml b/model_zoo/official/cv/lenet/default_config.yaml index b6d7ecb3631..fe686ee22f4 100644 --- a/model_zoo/official/cv/lenet/default_config.yaml +++ b/model_zoo/official/cv/lenet/default_config.yaml @@ -6,13 +6,12 @@ checkpoint_url: "" data_path: "/cache/data" output_path: "/cache/train" load_path: "/cache/checkpoint_path" -checkpoint_path: './checkpoint/' -checkpoint_file: './checkpoint/checkpoint_lenet-10_1875.ckpt' device_target: Ascend enable_profiling: False -data_path_local: '/data/hcm/data/MNIST_Data/' -ckpt_path_local: '/data/hcm/data/ckpt_lenet/checkpoint_lenet-10_1875.ckpt' +ckpt_path: '/cache/data/' +ckpt_file: '/cache/data/checkpoint_lenet-10_1875.ckpt' + # ============================================================================== # Training options num_classes: 10 diff --git a/model_zoo/official/cv/lenet/eval.py b/model_zoo/official/cv/lenet/eval.py index 11b6dd876cc..77a2a593581 100644 --- a/model_zoo/official/cv/lenet/eval.py +++ b/model_zoo/official/cv/lenet/eval.py @@ -19,27 +19,20 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt """ import os -# import sys -# sys.path.append(os.path.join(os.getcwd(), 'utils')) -from utils.config import config -from utils.moxing_adapter import moxing_wrapper +from src.model_utils.config import config +from src.model_utils.moxing_adapter import moxing_wrapper +from src.dataset import create_dataset +from src.lenet import LeNet5 import mindspore.nn as nn from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train import Model from mindspore.nn.metrics import Accuracy -from src.dataset import create_dataset -from src.lenet import LeNet5 -if os.path.exists(config.data_path_local): - config.data_path = config.data_path_local - ckpt_path = config.ckpt_path_local -else: - ckpt_path = os.path.join(config.data_path, 'checkpoint_lenet-10_1875.ckpt') def modelarts_process(): - pass + config.ckpt_path = config.ckpt_file @moxing_wrapper(pre_process=modelarts_process) def eval_lenet(): @@ -53,7 +46,7 @@ def eval_lenet(): model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) print("============== Starting Testing ==============") - param_dict = load_checkpoint(ckpt_path) + param_dict = load_checkpoint(config.ckpt_path) load_param_into_net(network, param_dict) ds_eval = create_dataset(os.path.join(config.data_path, "test"), config.batch_size, diff --git a/model_zoo/official/cv/lenet/export.py b/model_zoo/official/cv/lenet/export.py index 6c8ce0a89bf..ce0ef6ee6ab 100644 --- a/model_zoo/official/cv/lenet/export.py +++ b/model_zoo/official/cv/lenet/export.py @@ -14,23 +14,15 @@ # ============================================================================ """export checkpoint file into air, onnx, mindir models""" -import os -# import sys -# sys.path.append(os.path.join(os.getcwd(), 'utils')) -from utils.config import config -from utils.device_adapter import get_device_id +from src.model_utils.config import config +from src.model_utils.device_adapter import get_device_id +from src.lenet import LeNet5 import numpy as np import mindspore from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export -from src.lenet import LeNet5 -if os.path.exists(config.data_path_local): - ckpt_file = config.ckpt_path_local -else: - ckpt_file = os.path.join(config.data_path, 'checkpoint_lenet-10_1875.ckpt') - context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target) if config.device_target == "Ascend": context.set_context(device_id=get_device_id()) @@ -40,7 +32,7 @@ if __name__ == "__main__": # define fusion network network = LeNet5(config.num_classes) # load network checkpoint - param_dict = load_checkpoint(ckpt_file) + param_dict = load_checkpoint(config.ckpt_file) load_param_into_net(network, param_dict) # export network diff --git a/model_zoo/official/cv/lenet/scripts/run_standalone_eval_ascend.sh b/model_zoo/official/cv/lenet/scripts/run_standalone_eval_ascend.sh index 0c95dfcfa93..df8c73475ce 100755 --- a/model_zoo/official/cv/lenet/scripts/run_standalone_eval_ascend.sh +++ b/model_zoo/official/cv/lenet/scripts/run_standalone_eval_ascend.sh @@ -17,7 +17,6 @@ # an simple tutorial as follows, more parameters can be setting script_self=$(readlink -f "$0") self_path=$(dirname "${script_self}") -# DATA_PATH=$1 -# CKPT_PATH=$2 -# --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH -python -s ${self_path}/../eval.py > log_eval.txt 2>&1 & +DATA_PATH=$1 +CKPT_PATH=$2 +python -s ${self_path}/../eval.py --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH > log_eval.txt 2>&1 & diff --git a/model_zoo/official/cv/lenet/scripts/run_standalone_train_ascend.sh b/model_zoo/official/cv/lenet/scripts/run_standalone_train_ascend.sh index 9884cb97be1..713192120ae 100755 --- a/model_zoo/official/cv/lenet/scripts/run_standalone_train_ascend.sh +++ b/model_zoo/official/cv/lenet/scripts/run_standalone_train_ascend.sh @@ -17,7 +17,6 @@ # an simple tutorial as follows, more parameters can be setting script_self=$(readlink -f "$0") self_path=$(dirname "${script_self}") -# DATA_PATH=$1 -# CKPT_PATH=$2 -# --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH -python -s ${self_path}/../train.py > log.txt 2>&1 & \ No newline at end of file +DATA_PATH=$1 +CKPT_PATH=$2 +python -s ${self_path}/../train.py --data_path=$DATA_PATH --device_target="Ascend" --ckpt_path=$CKPT_PATH > log.txt 2>&1 & diff --git a/model_zoo/official/cv/lenet/utils/__init__.py b/model_zoo/official/cv/lenet/src/model_utils/__init__.py similarity index 100% rename from model_zoo/official/cv/lenet/utils/__init__.py rename to model_zoo/official/cv/lenet/src/model_utils/__init__.py diff --git a/model_zoo/official/cv/lenet/utils/config.py b/model_zoo/official/cv/lenet/src/model_utils/config.py similarity index 98% rename from model_zoo/official/cv/lenet/utils/config.py rename to model_zoo/official/cv/lenet/src/model_utils/config.py index 2c191e9f748..7f1ff6e2b8d 100644 --- a/model_zoo/official/cv/lenet/utils/config.py +++ b/model_zoo/official/cv/lenet/src/model_utils/config.py @@ -115,7 +115,7 @@ def get_config(): """ parser = argparse.ArgumentParser(description="default name", add_help=False) current_dir = os.path.dirname(os.path.abspath(__file__)) - parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../default_config.yaml"), + parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"), help="Config file path") path_args, _ = parser.parse_known_args() default, helper, choices = parse_yaml(path_args.config_path) diff --git a/model_zoo/official/cv/lenet/utils/device_adapter.py b/model_zoo/official/cv/lenet/src/model_utils/device_adapter.py similarity index 100% rename from model_zoo/official/cv/lenet/utils/device_adapter.py rename to model_zoo/official/cv/lenet/src/model_utils/device_adapter.py diff --git a/model_zoo/official/cv/lenet/utils/local_adapter.py b/model_zoo/official/cv/lenet/src/model_utils/local_adapter.py similarity index 100% rename from model_zoo/official/cv/lenet/utils/local_adapter.py rename to model_zoo/official/cv/lenet/src/model_utils/local_adapter.py diff --git a/model_zoo/official/cv/lenet/utils/moxing_adapter.py b/model_zoo/official/cv/lenet/src/model_utils/moxing_adapter.py similarity index 100% rename from model_zoo/official/cv/lenet/utils/moxing_adapter.py rename to model_zoo/official/cv/lenet/src/model_utils/moxing_adapter.py diff --git a/model_zoo/official/cv/lenet/train.py b/model_zoo/official/cv/lenet/train.py index 2d5be9a4474..2da0665df51 100644 --- a/model_zoo/official/cv/lenet/train.py +++ b/model_zoo/official/cv/lenet/train.py @@ -19,14 +19,11 @@ python train.py --data_path /YourDataPath """ import os -# import sys -# sys.path.append(os.path.join(os.getcwd(), 'utils')) -from utils.config import config -from utils.moxing_adapter import moxing_wrapper -from utils.device_adapter import get_rank_id - +from src.model_utils.config import config +from src.model_utils.moxing_adapter import moxing_wrapper from src.dataset import create_dataset from src.lenet import LeNet5 + import mindspore.nn as nn from mindspore import context from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor @@ -36,12 +33,6 @@ from mindspore.common import set_seed set_seed(1) -if os.path.exists(config.data_path_local): - config.data_path = config.data_path_local - config.checkpoint_path = os.path.join(config.checkpoint_path, str(get_rank_id())) -else: - config.checkpoint_path = os.path.join(config.output_path, config.checkpoint_path, str(get_rank_id())) - def modelarts_pre_process(): pass @@ -59,7 +50,7 @@ def train_lenet(): time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, keep_checkpoint_max=config.keep_checkpoint_max) - ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=config.checkpoint_path, config=config_ck) + ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=config.ckpt_path, config=config_ck) if config.device_target != "Ascend": model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})