forked from mindspore-Ecosystem/mindspore
checkpoint add model_type
This commit is contained in:
parent
8870956954
commit
d3f9b80066
|
@ -593,6 +593,17 @@ def check_bool(input_param):
|
||||||
raise TypeError("Input type must be bool!")
|
raise TypeError("Input type must be bool!")
|
||||||
|
|
||||||
|
|
||||||
|
def check_string(input_param, valid_values):
|
||||||
|
"""String type judgment."""
|
||||||
|
if isinstance(input_param, str) and input_param in valid_values:
|
||||||
|
return input_param
|
||||||
|
if len(valid_values) == 1:
|
||||||
|
raise ValueError(f'Input should be str and must be {valid_values[0]},'
|
||||||
|
f' but got {input_param}.')
|
||||||
|
raise ValueError(f'Input should be str and must be one of {valid_values},'
|
||||||
|
f' but got {input_param}.')
|
||||||
|
|
||||||
|
|
||||||
def check_input_format(input_param):
|
def check_input_format(input_param):
|
||||||
"""Judge input format."""
|
"""Judge input format."""
|
||||||
if input_param == "NCHW":
|
if input_param == "NCHW":
|
||||||
|
|
|
@ -22,6 +22,7 @@ message Checkpoint {
|
||||||
required TensorProto tensor = 2;
|
required TensorProto tensor = 2;
|
||||||
}
|
}
|
||||||
repeated Value value = 1;
|
repeated Value value = 1;
|
||||||
|
required string model_type = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -21,17 +21,16 @@ import time
|
||||||
|
|
||||||
import mindspore.context as context
|
import mindspore.context as context
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from mindspore._checkparam import check_bool, check_int_non_negative
|
from mindspore._checkparam import check_bool, check_string, check_int_non_negative
|
||||||
from mindspore.train._utils import _make_directory
|
from mindspore.train._utils import _make_directory
|
||||||
from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
|
from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
|
||||||
|
|
||||||
from ._callback import Callback, set_cur_net
|
from ._callback import Callback, set_cur_net
|
||||||
|
|
||||||
|
|
||||||
_cur_dir = os.getcwd()
|
_cur_dir = os.getcwd()
|
||||||
_save_dir = _cur_dir
|
_save_dir = _cur_dir
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _check_file_name_prefix(file_name_prefix):
|
def _check_file_name_prefix(file_name_prefix):
|
||||||
"""
|
"""
|
||||||
Check file name valid or not.
|
Check file name valid or not.
|
||||||
|
@ -87,6 +86,7 @@ class CheckpointConfig:
|
||||||
Can't be used with keep_checkpoint_max at the same time.
|
Can't be used with keep_checkpoint_max at the same time.
|
||||||
integrated_save (bool): Whether to perform integrated save in automatic model parallel scene. Default: True.
|
integrated_save (bool): Whether to perform integrated save in automatic model parallel scene. Default: True.
|
||||||
Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
|
Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
|
||||||
|
model_type (str): Model type in `normal`, `fusion` or `quant`. Default: "normal".
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the input_param is None or 0.
|
ValueError: If the input_param is None or 0.
|
||||||
|
@ -101,7 +101,8 @@ class CheckpointConfig:
|
||||||
save_checkpoint_seconds=0,
|
save_checkpoint_seconds=0,
|
||||||
keep_checkpoint_max=5,
|
keep_checkpoint_max=5,
|
||||||
keep_checkpoint_per_n_minutes=0,
|
keep_checkpoint_per_n_minutes=0,
|
||||||
integrated_save=True):
|
integrated_save=True,
|
||||||
|
model_type="normal"):
|
||||||
|
|
||||||
if not save_checkpoint_steps and not save_checkpoint_seconds and \
|
if not save_checkpoint_steps and not save_checkpoint_seconds and \
|
||||||
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
|
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
|
||||||
|
@ -115,6 +116,8 @@ class CheckpointConfig:
|
||||||
keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
|
keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
|
||||||
if keep_checkpoint_per_n_minutes:
|
if keep_checkpoint_per_n_minutes:
|
||||||
keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
|
keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
|
||||||
|
if model_type:
|
||||||
|
model_type = check_string(model_type, ["normal", "fusion", "quant"])
|
||||||
|
|
||||||
self._save_checkpoint_steps = save_checkpoint_steps
|
self._save_checkpoint_steps = save_checkpoint_steps
|
||||||
self._save_checkpoint_seconds = save_checkpoint_seconds
|
self._save_checkpoint_seconds = save_checkpoint_seconds
|
||||||
|
@ -129,6 +132,7 @@ class CheckpointConfig:
|
||||||
if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
|
if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
|
||||||
self._keep_checkpoint_max = 1
|
self._keep_checkpoint_max = 1
|
||||||
|
|
||||||
|
self._model_type = model_type
|
||||||
self._integrated_save = check_bool(integrated_save)
|
self._integrated_save = check_bool(integrated_save)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -156,12 +160,18 @@ class CheckpointConfig:
|
||||||
"""Get the value of _integrated_save."""
|
"""Get the value of _integrated_save."""
|
||||||
return self._integrated_save
|
return self._integrated_save
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_type(self):
|
||||||
|
"""Get the value of model_type."""
|
||||||
|
return self._model_type
|
||||||
|
|
||||||
def get_checkpoint_policy(self):
|
def get_checkpoint_policy(self):
|
||||||
"""Get the policy of checkpoint."""
|
"""Get the policy of checkpoint."""
|
||||||
checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
|
checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
|
||||||
'save_checkpoint_seconds': self._save_checkpoint_seconds,
|
'save_checkpoint_seconds': self._save_checkpoint_seconds,
|
||||||
'keep_checkpoint_max': self._keep_checkpoint_max,
|
'keep_checkpoint_max': self._keep_checkpoint_max,
|
||||||
'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes}
|
'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes,
|
||||||
|
'model_type': self._model_type}
|
||||||
|
|
||||||
return checkpoint_policy
|
return checkpoint_policy
|
||||||
|
|
||||||
|
@ -226,7 +236,7 @@ class ModelCheckpoint(Callback):
|
||||||
graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
|
graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
|
||||||
_save_graph(cb_params.train_network, graph_file_name)
|
_save_graph(cb_params.train_network, graph_file_name)
|
||||||
self._graph_saved = True
|
self._graph_saved = True
|
||||||
self._save_ckpt(cb_params)
|
self._save_ckpt(cb_params, self._config.model_type)
|
||||||
|
|
||||||
def end(self, run_context):
|
def end(self, run_context):
|
||||||
"""
|
"""
|
||||||
|
@ -237,7 +247,7 @@ class ModelCheckpoint(Callback):
|
||||||
"""
|
"""
|
||||||
cb_params = run_context.original_args()
|
cb_params = run_context.original_args()
|
||||||
_to_save_last_ckpt = True
|
_to_save_last_ckpt = True
|
||||||
self._save_ckpt(cb_params, _to_save_last_ckpt)
|
self._save_ckpt(cb_params, self._config.model_type, _to_save_last_ckpt)
|
||||||
|
|
||||||
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
|
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
|
||||||
destroy_allgather_cell()
|
destroy_allgather_cell()
|
||||||
|
@ -256,7 +266,7 @@ class ModelCheckpoint(Callback):
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _save_ckpt(self, cb_params, force_to_save=False):
|
def _save_ckpt(self, cb_params, model_type, force_to_save=False):
|
||||||
"""Save checkpoint files."""
|
"""Save checkpoint files."""
|
||||||
if cb_params.cur_step_num == self._last_triggered_step:
|
if cb_params.cur_step_num == self._last_triggered_step:
|
||||||
return
|
return
|
||||||
|
@ -292,7 +302,7 @@ class ModelCheckpoint(Callback):
|
||||||
set_cur_net(cb_params.train_network)
|
set_cur_net(cb_params.train_network)
|
||||||
cb_params.train_network.exec_checkpoint_graph()
|
cb_params.train_network.exec_checkpoint_graph()
|
||||||
|
|
||||||
_exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
|
_exec_save_checkpoint(cb_params.train_network, gen_file, model_type, self._config.integrated_save)
|
||||||
|
|
||||||
if os.path.exists(gen_file):
|
if os.path.exists(gen_file):
|
||||||
shutil.move(gen_file, cur_file)
|
shutil.move(gen_file, cur_file)
|
||||||
|
|
|
@ -76,7 +76,7 @@ class LossMonitor(Callback):
|
||||||
step_loss = np.mean(step_loss.asnumpy())
|
step_loss = np.mean(step_loss.asnumpy())
|
||||||
|
|
||||||
self.losses.append(step_loss)
|
self.losses.append(step_loss)
|
||||||
cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num)
|
cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num) + 1
|
||||||
|
|
||||||
if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)):
|
if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)):
|
||||||
raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. "
|
raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. "
|
||||||
|
@ -87,7 +87,7 @@ class LossMonitor(Callback):
|
||||||
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
|
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
|
||||||
print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
|
print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
|
||||||
"loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format(
|
"loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format(
|
||||||
cb_params.cur_epoch_num - 1, cb_params.epoch_num,
|
cb_params.cur_epoch_num, cb_params.epoch_num,
|
||||||
cur_step_in_epoch, int(cb_params.batch_num),
|
cur_step_in_epoch, int(cb_params.batch_num),
|
||||||
step_loss, np.mean(self.losses),
|
step_loss, np.mean(self.losses),
|
||||||
step_mseconds), flush=True)
|
step_mseconds), flush=True)
|
||||||
|
|
|
@ -29,6 +29,7 @@ from mindspore.common.api import _executor
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore._checkparam import check_input_data
|
from mindspore._checkparam import check_input_data
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"]
|
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"]
|
||||||
|
|
||||||
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
|
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
|
||||||
|
@ -40,6 +41,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
|
||||||
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
|
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
|
||||||
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
|
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
|
||||||
|
|
||||||
|
ModelType = ["normal", "fusion", "quant"]
|
||||||
|
|
||||||
|
|
||||||
def _special_process_par(par, new_par):
|
def _special_process_par(par, new_par):
|
||||||
"""
|
"""
|
||||||
|
@ -101,20 +104,22 @@ def _update_param(param, new_param):
|
||||||
param.set_parameter_data(type(param.data)(new_param.data))
|
param.set_parameter_data(type(param.data)(new_param.data))
|
||||||
|
|
||||||
|
|
||||||
def save_checkpoint(parameter_list, ckpoint_file_name):
|
def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
|
||||||
"""
|
"""
|
||||||
Saves checkpoint info to a specified file.
|
Saves checkpoint info to a specified file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
parameter_list (list): Parameters list, each element is a dict
|
parameter_list (list): Parameters list, each element is a dict
|
||||||
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
|
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
|
||||||
ckpoint_file_name (str): Checkpoint file name.
|
ckpt_file_name (str): Checkpoint file name.
|
||||||
|
model_type (str): The name of model type. Default: "normal".
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: Failed to save the Checkpoint file.
|
RuntimeError: Failed to save the Checkpoint file.
|
||||||
"""
|
"""
|
||||||
logger.info("Execute save checkpoint process.")
|
logger.info("Execute save checkpoint process.")
|
||||||
checkpoint_list = Checkpoint()
|
checkpoint_list = Checkpoint()
|
||||||
|
checkpoint_list.model_type = model_type
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for param in parameter_list:
|
for param in parameter_list:
|
||||||
|
@ -133,22 +138,23 @@ def save_checkpoint(parameter_list, ckpoint_file_name):
|
||||||
for dim in param['data'].shape:
|
for dim in param['data'].shape:
|
||||||
param_tensor.dims.append(dim)
|
param_tensor.dims.append(dim)
|
||||||
|
|
||||||
with open(ckpoint_file_name, "wb") as f:
|
with open(ckpt_file_name, "wb") as f:
|
||||||
f.write(checkpoint_list.SerializeToString())
|
f.write(checkpoint_list.SerializeToString())
|
||||||
os.chmod(ckpoint_file_name, stat.S_IRUSR)
|
os.chmod(ckpt_file_name, stat.S_IRUSR)
|
||||||
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error("Failed to save the checkpoint file %s.", ckpoint_file_name)
|
logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
|
||||||
raise RuntimeError(e.__str__())
|
raise RuntimeError(e.__str__())
|
||||||
logger.info("Save checkpoint process finish.")
|
logger.info("Save checkpoint process finish.")
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(ckpoint_file_name, net=None):
|
def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
|
||||||
"""
|
"""
|
||||||
Loads checkpoint info from a specified file.
|
Loads checkpoint info from a specified file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ckpoint_file_name (str): Checkpoint file name.
|
ckpt_file_name (str): Checkpoint file name.
|
||||||
|
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
|
||||||
net (Cell): Cell network. Default: None
|
net (Cell): Cell network. Default: None
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -157,28 +163,33 @@ def load_checkpoint(ckpoint_file_name, net=None):
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: Checkpoint file is incorrect.
|
ValueError: Checkpoint file is incorrect.
|
||||||
"""
|
"""
|
||||||
if not isinstance(ckpoint_file_name, str):
|
if not isinstance(ckpt_file_name, str):
|
||||||
raise ValueError("The ckpoint_file_name must be String.")
|
raise ValueError("The ckpt_file_name must be string.")
|
||||||
|
|
||||||
if not os.path.exists(ckpoint_file_name) or ckpoint_file_name[-5:] != ".ckpt":
|
if model_type not in ModelType:
|
||||||
|
raise ValueError(f"The model_type is not in {ModelType}.")
|
||||||
|
|
||||||
|
if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt":
|
||||||
raise ValueError("Please input the correct checkpoint file name.")
|
raise ValueError("Please input the correct checkpoint file name.")
|
||||||
|
|
||||||
if os.path.getsize(ckpoint_file_name) == 0:
|
if os.path.getsize(ckpt_file_name) == 0:
|
||||||
raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")
|
raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")
|
||||||
|
|
||||||
logger.info("Execute load checkpoint process.")
|
logger.info("Execute load checkpoint process.")
|
||||||
checkpoint_list = Checkpoint()
|
checkpoint_list = Checkpoint()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(ckpoint_file_name, "rb") as f:
|
with open(ckpt_file_name, "rb") as f:
|
||||||
pb_content = f.read()
|
pb_content = f.read()
|
||||||
checkpoint_list.ParseFromString(pb_content)
|
checkpoint_list.ParseFromString(pb_content)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error("Failed to read the checkpoint file %s, please check the correct of the file.", ckpoint_file_name)
|
logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
|
||||||
raise ValueError(e.__str__())
|
raise ValueError(e.__str__())
|
||||||
|
|
||||||
parameter_dict = {}
|
parameter_dict = {}
|
||||||
|
if model_type != checkpoint_list.model_type:
|
||||||
|
raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
|
||||||
|
checkpoint_list.model_type, model_type))
|
||||||
try:
|
try:
|
||||||
for element in checkpoint_list.value:
|
for element in checkpoint_list.value:
|
||||||
data = element.tensor.tensor_content
|
data = element.tensor.tensor_content
|
||||||
|
@ -206,7 +217,7 @@ def load_checkpoint(ckpoint_file_name, net=None):
|
||||||
logger.info("Load checkpoint process finish.")
|
logger.info("Load checkpoint process finish.")
|
||||||
|
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error("Failed to load the checkpoint file %s.", ckpoint_file_name)
|
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
|
||||||
raise RuntimeError(e.__str__())
|
raise RuntimeError(e.__str__())
|
||||||
|
|
||||||
if net:
|
if net:
|
||||||
|
@ -303,14 +314,15 @@ def _save_graph(network, file_name):
|
||||||
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
||||||
|
|
||||||
|
|
||||||
def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True):
|
def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", integrated_save=True):
|
||||||
"""
|
"""
|
||||||
Saves checkpoint for 'ms' backend.
|
Saves checkpoint for 'ms' backend.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
train_network (Network): The train network for training.
|
train_network (Network): The train network for training.
|
||||||
ckpoint_file_name (str): The name of checkpoint file.
|
ckpt_file_name (str): The name of checkpoint file.
|
||||||
integrated_save (bool): Whether to intergrated save in automatic model parallel scene.
|
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
|
||||||
|
integrated_save (bool): Whether to integrated save in automatic model parallel scene.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
param_dict = {}
|
param_dict = {}
|
||||||
|
@ -334,7 +346,7 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True
|
||||||
each_param["data"] = param_data
|
each_param["data"] = param_data
|
||||||
param_list.append(each_param)
|
param_list.append(each_param)
|
||||||
|
|
||||||
save_checkpoint(param_list, ckpoint_file_name)
|
save_checkpoint(param_list, ckpt_file_name, model_type)
|
||||||
|
|
||||||
|
|
||||||
def _get_merged_param_data(net, param_name, param_data):
|
def _get_merged_param_data(net, param_name, param_data):
|
||||||
|
|
|
@ -20,16 +20,14 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
from src.dataset import create_dataset
|
|
||||||
from src.config import mnist_cfg as cfg
|
|
||||||
from src.lenet import LeNet5
|
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
|
||||||
from mindspore.train import Model
|
from mindspore.train import Model
|
||||||
from mindspore.nn.metrics import Accuracy
|
from mindspore.nn.metrics import Accuracy
|
||||||
|
from src.dataset import create_dataset
|
||||||
|
from src.config import mnist_cfg as cfg
|
||||||
|
from src.lenet import LeNet5
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
|
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
|
||||||
|
@ -49,9 +47,6 @@ if __name__ == "__main__":
|
||||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
||||||
repeat_size = cfg.epoch_size
|
repeat_size = cfg.epoch_size
|
||||||
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
||||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
|
|
||||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
|
||||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
|
|
||||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||||
|
|
||||||
print("============== Starting Testing ==============")
|
print("============== Starting Testing ==============")
|
||||||
|
|
|
@ -128,9 +128,9 @@ After all the following we will get the loss value of each step as following:
|
||||||
```bash
|
```bash
|
||||||
>>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
|
>>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
|
||||||
>>> ...
|
>>> ...
|
||||||
>>> Epoch: [ 10/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
|
>>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
|
||||||
>>> Epoch: [ 10/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
|
>>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
|
||||||
>>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
|
>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
|
||||||
```
|
```
|
||||||
|
|
||||||
Also, you can just run this command instead.
|
Also, you can just run this command instead.
|
||||||
|
@ -197,9 +197,9 @@ After all the following we will get the loss value of each step as following:
|
||||||
```bash
|
```bash
|
||||||
>>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
|
>>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
|
||||||
>>> ...
|
>>> ...
|
||||||
>>> Epoch: [ 10/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
|
>>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
|
||||||
>>> Epoch: [ 10/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
|
>>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
|
||||||
>>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
|
>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
|
||||||
```
|
```
|
||||||
|
|
||||||
### Evaluate quantization aware model
|
### Evaluate quantization aware model
|
||||||
|
@ -215,7 +215,7 @@ param_dict = load_checkpoint(args.ckpt_path)
|
||||||
load_param_into_net(network, param_dict)
|
load_param_into_net(network, param_dict)
|
||||||
|
|
||||||
# convert fusion network to quantization aware network
|
# convert fusion network to quantization aware network
|
||||||
network = quant.convert_quant_network(network
|
network = quant.convert_quant_network(network)
|
||||||
```
|
```
|
||||||
|
|
||||||
Also, you can just run this command instead.
|
Also, you can just run this command instead.
|
||||||
|
|
|
@ -23,7 +23,6 @@ import argparse
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
|
||||||
from mindspore.train import Model
|
from mindspore.train import Model
|
||||||
from mindspore.nn.metrics import Accuracy
|
from mindspore.nn.metrics import Accuracy
|
||||||
from src.dataset import create_dataset
|
from src.dataset import create_dataset
|
||||||
|
@ -47,16 +46,18 @@ if __name__ == "__main__":
|
||||||
ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
|
ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
|
||||||
step_size = ds_eval.get_dataset_size()
|
step_size = ds_eval.get_dataset_size()
|
||||||
|
|
||||||
|
# define fusion network
|
||||||
network = LeNet5Fusion(cfg.num_classes)
|
network = LeNet5Fusion(cfg.num_classes)
|
||||||
|
# define loss
|
||||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
||||||
repeat_size = cfg.epoch_size
|
# define network optimization
|
||||||
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
||||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
|
||||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
# call back and monitor
|
||||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
|
|
||||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||||
|
|
||||||
param_dict = load_checkpoint(args.ckpt_path)
|
# load check point into network
|
||||||
|
param_dict = load_checkpoint(args.ckpt_path, network.type)
|
||||||
load_param_into_net(network, param_dict)
|
load_param_into_net(network, param_dict)
|
||||||
|
|
||||||
print("============== Starting Testing ==============")
|
print("============== Starting Testing ==============")
|
||||||
|
|
|
@ -23,7 +23,6 @@ import argparse
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
|
||||||
from mindspore.train import Model
|
from mindspore.train import Model
|
||||||
from mindspore.nn.metrics import Accuracy
|
from mindspore.nn.metrics import Accuracy
|
||||||
from mindspore.train.quant import quant
|
from mindspore.train.quant import quant
|
||||||
|
@ -48,20 +47,21 @@ if __name__ == "__main__":
|
||||||
ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
|
ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
|
||||||
step_size = ds_eval.get_dataset_size()
|
step_size = ds_eval.get_dataset_size()
|
||||||
|
|
||||||
# define funsion network
|
# define fusion network
|
||||||
network = LeNet5Fusion(cfg.num_classes)
|
network = LeNet5Fusion(cfg.num_classes)
|
||||||
# convert funsion netwrok to quantization aware network
|
# convert fusion netwrok to quantization aware network
|
||||||
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
|
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
|
||||||
|
|
||||||
|
# define loss
|
||||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
||||||
|
# define network optimization
|
||||||
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
||||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
|
||||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
# call back and monitor
|
||||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
|
|
||||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||||
|
|
||||||
# load quantization aware network checkpoint
|
# load quantization aware network checkpoint
|
||||||
param_dict = load_checkpoint(args.ckpt_path)
|
param_dict = load_checkpoint(args.ckpt_path, model_type="quant")
|
||||||
load_param_into_net(network, param_dict)
|
load_param_into_net(network, param_dict)
|
||||||
|
|
||||||
print("============== Starting Testing ==============")
|
print("============== Starting Testing ==============")
|
||||||
|
|
|
@ -34,8 +34,8 @@ class LeNet5(nn.Cell):
|
||||||
super(LeNet5, self).__init__()
|
super(LeNet5, self).__init__()
|
||||||
self.num_class = num_class
|
self.num_class = num_class
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(channel, 6, 5)
|
self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid')
|
||||||
self.conv2 = nn.Conv2d(6, 16, 5)
|
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
|
||||||
self.fc1 = nn.Dense(16 * 5 * 5, 120)
|
self.fc1 = nn.Dense(16 * 5 * 5, 120)
|
||||||
self.fc2 = nn.Dense(120, 84)
|
self.fc2 = nn.Dense(120, 84)
|
||||||
self.fc3 = nn.Dense(84, self.num_class)
|
self.fc3 = nn.Dense(84, self.num_class)
|
||||||
|
|
|
@ -32,11 +32,12 @@ class LeNet5(nn.Cell):
|
||||||
|
|
||||||
def __init__(self, num_class=10, channel=1):
|
def __init__(self, num_class=10, channel=1):
|
||||||
super(LeNet5, self).__init__()
|
super(LeNet5, self).__init__()
|
||||||
|
self.type = "fusion"
|
||||||
self.num_class = num_class
|
self.num_class = num_class
|
||||||
|
|
||||||
# change `nn.Conv2d` to `nn.Conv2dBnAct`
|
# change `nn.Conv2d` to `nn.Conv2dBnAct`
|
||||||
self.conv1 = nn.Conv2dBnAct(channel, 6, 5, activation='relu')
|
self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
|
||||||
self.conv2 = nn.Conv2dBnAct(6, 16, 5, activation='relu')
|
self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
|
||||||
# change `nn.Dense` to `nn.DenseBnAct`
|
# change `nn.Dense` to `nn.DenseBnAct`
|
||||||
self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
|
self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
|
||||||
self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
|
self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
|
||||||
|
|
|
@ -46,16 +46,24 @@ if __name__ == "__main__":
|
||||||
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
|
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
|
||||||
step_size = ds_train.get_dataset_size()
|
step_size = ds_train.get_dataset_size()
|
||||||
|
|
||||||
|
# define fusion network
|
||||||
network = LeNet5Fusion(cfg.num_classes)
|
network = LeNet5Fusion(cfg.num_classes)
|
||||||
|
# define network loss
|
||||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
||||||
|
# define network optimization
|
||||||
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
||||||
|
|
||||||
|
# call back and monitor
|
||||||
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
||||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
||||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
keep_checkpoint_max=cfg.keep_checkpoint_max,
|
||||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
|
model_type=network.type)
|
||||||
|
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
|
||||||
|
|
||||||
|
# define model
|
||||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||||
|
|
||||||
print("============== Starting Training ==============")
|
print("============== Starting Training ==============")
|
||||||
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
|
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
|
||||||
dataset_sink_mode=args.dataset_sink_mode)
|
dataset_sink_mode=args.dataset_sink_mode)
|
||||||
print("============== End Training ==============")
|
print("============== End Training ==============")
|
||||||
|
|
|
@ -48,23 +48,30 @@ if __name__ == "__main__":
|
||||||
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
|
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
|
||||||
step_size = ds_train.get_dataset_size()
|
step_size = ds_train.get_dataset_size()
|
||||||
|
|
||||||
# define funsion network
|
# define fusion network
|
||||||
network = LeNet5Fusion(cfg.num_classes)
|
network = LeNet5Fusion(cfg.num_classes)
|
||||||
# load quantization aware network checkpoint
|
# load quantization aware network checkpoint
|
||||||
param_dict = load_checkpoint(args.ckpt_path)
|
param_dict = load_checkpoint(args.ckpt_path, network.type)
|
||||||
load_param_into_net(network, param_dict)
|
load_param_into_net(network, param_dict)
|
||||||
# convert funsion netwrok to quantization aware network
|
# convert fusion network to quantization aware network
|
||||||
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
|
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
|
||||||
|
|
||||||
|
# define network loss
|
||||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
||||||
|
# define network optimization
|
||||||
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
||||||
|
|
||||||
|
# call back and monitor
|
||||||
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
||||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
||||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
keep_checkpoint_max=cfg.keep_checkpoint_max,
|
||||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
|
model_type="quant")
|
||||||
|
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
|
||||||
|
|
||||||
|
# define model
|
||||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||||
|
|
||||||
print("============== Starting Training ==============")
|
print("============== Starting Training ==============")
|
||||||
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
|
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
|
||||||
dataset_sink_mode=args.dataset_sink_mode)
|
dataset_sink_mode=args.dataset_sink_mode)
|
||||||
print("============== End Training ==============")
|
print("============== End Training ==============")
|
||||||
|
|
|
@ -85,7 +85,7 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
is_ckpt_exist = os.path.exists(ckpt_file_path)
|
is_ckpt_exist = os.path.exists(ckpt_file_path)
|
||||||
if is_ckpt_exist:
|
if is_ckpt_exist:
|
||||||
param_dict = load_checkpoint(ckpoint_file_name=ckpt_file_path)
|
param_dict = load_checkpoint(ckpt_file_name=ckpt_file_path)
|
||||||
load_param_into_net(net, param_dict)
|
load_param_into_net(net, param_dict)
|
||||||
export(net, input_data, file_name=model_path_name, file_format='LITE')
|
export(net, input_data, file_name=model_path_name, file_format='LITE')
|
||||||
print("test lenet predict success.")
|
print("test lenet predict success.")
|
||||||
|
|
|
@ -111,19 +111,19 @@ def test_save_checkpoint():
|
||||||
os.chmod('./parameters.ckpt', stat.S_IWRITE)
|
os.chmod('./parameters.ckpt', stat.S_IWRITE)
|
||||||
os.remove('./parameters.ckpt')
|
os.remove('./parameters.ckpt')
|
||||||
|
|
||||||
ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt')
|
ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
|
||||||
save_checkpoint(parameter_list, ckpoint_file_name)
|
save_checkpoint(parameter_list, ckpt_file_name)
|
||||||
|
|
||||||
|
|
||||||
def test_load_checkpoint_error_filename():
|
def test_load_checkpoint_error_filename():
|
||||||
ckpoint_file_name = 1
|
ckpt_file_name = 1
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
load_checkpoint(ckpoint_file_name)
|
load_checkpoint(ckpt_file_name)
|
||||||
|
|
||||||
|
|
||||||
def test_load_checkpoint():
|
def test_load_checkpoint():
|
||||||
ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt')
|
ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
|
||||||
par_dict = load_checkpoint(ckpoint_file_name)
|
par_dict = load_checkpoint(ckpt_file_name)
|
||||||
|
|
||||||
assert len(par_dict) == 3
|
assert len(par_dict) == 3
|
||||||
assert par_dict['param_test'].name == 'param_test'
|
assert par_dict['param_test'].name == 'param_test'
|
||||||
|
@ -136,17 +136,17 @@ def test_checkpoint_manager():
|
||||||
""" test_checkpoint_manager """
|
""" test_checkpoint_manager """
|
||||||
ckp_mgr = _CheckpointManager()
|
ckp_mgr = _CheckpointManager()
|
||||||
|
|
||||||
ckpoint_file_name = os.path.join(_cur_dir, './test1.ckpt')
|
ckpt_file_name = os.path.join(_cur_dir, './test1.ckpt')
|
||||||
with open(ckpoint_file_name, 'w'):
|
with open(ckpt_file_name, 'w'):
|
||||||
os.chmod(ckpoint_file_name, stat.S_IWUSR | stat.S_IRUSR)
|
os.chmod(ckpt_file_name, stat.S_IWUSR | stat.S_IRUSR)
|
||||||
|
|
||||||
ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
|
ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
|
||||||
assert ckp_mgr.ckpoint_num == 1
|
assert ckp_mgr.ckpoint_num == 1
|
||||||
|
|
||||||
ckp_mgr.remove_ckpoint_file(ckpoint_file_name)
|
ckp_mgr.remove_ckpoint_file(ckpt_file_name)
|
||||||
ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
|
ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
|
||||||
assert ckp_mgr.ckpoint_num == 0
|
assert ckp_mgr.ckpoint_num == 0
|
||||||
assert not os.path.exists(ckpoint_file_name)
|
assert not os.path.exists(ckpt_file_name)
|
||||||
|
|
||||||
another_file_name = os.path.join(_cur_dir, './test2.ckpt')
|
another_file_name = os.path.join(_cur_dir, './test2.ckpt')
|
||||||
another_file_name = os.path.realpath(another_file_name)
|
another_file_name = os.path.realpath(another_file_name)
|
||||||
|
@ -283,7 +283,7 @@ def test_exec_save_checkpoint():
|
||||||
|
|
||||||
loss_net = WithLossCell(net, loss)
|
loss_net = WithLossCell(net, loss)
|
||||||
train_network = TrainOneStepCell(loss_net, opt)
|
train_network = TrainOneStepCell(loss_net, opt)
|
||||||
_exec_save_checkpoint(train_network, ckpoint_file_name="./new_ckpt.ckpt")
|
_exec_save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt")
|
||||||
|
|
||||||
load_checkpoint("new_ckpt.ckpt")
|
load_checkpoint("new_ckpt.ckpt")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue