checkpoint add model_type
This commit is contained in:
parent
8870956954
commit
d3f9b80066
|
@ -593,6 +593,17 @@ def check_bool(input_param):
|
|||
raise TypeError("Input type must be bool!")
|
||||
|
||||
|
||||
def check_string(input_param, valid_values):
|
||||
"""String type judgment."""
|
||||
if isinstance(input_param, str) and input_param in valid_values:
|
||||
return input_param
|
||||
if len(valid_values) == 1:
|
||||
raise ValueError(f'Input should be str and must be {valid_values[0]},'
|
||||
f' but got {input_param}.')
|
||||
raise ValueError(f'Input should be str and must be one of {valid_values},'
|
||||
f' but got {input_param}.')
|
||||
|
||||
|
||||
def check_input_format(input_param):
|
||||
"""Judge input format."""
|
||||
if input_param == "NCHW":
|
||||
|
|
|
@ -22,6 +22,7 @@ message Checkpoint {
|
|||
required TensorProto tensor = 2;
|
||||
}
|
||||
repeated Value value = 1;
|
||||
required string model_type = 2;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -21,17 +21,16 @@ import time
|
|||
|
||||
import mindspore.context as context
|
||||
from mindspore import log as logger
|
||||
from mindspore._checkparam import check_bool, check_int_non_negative
|
||||
from mindspore._checkparam import check_bool, check_string, check_int_non_negative
|
||||
from mindspore.train._utils import _make_directory
|
||||
from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
|
||||
|
||||
from ._callback import Callback, set_cur_net
|
||||
|
||||
|
||||
_cur_dir = os.getcwd()
|
||||
_save_dir = _cur_dir
|
||||
|
||||
|
||||
|
||||
def _check_file_name_prefix(file_name_prefix):
|
||||
"""
|
||||
Check file name valid or not.
|
||||
|
@ -87,6 +86,7 @@ class CheckpointConfig:
|
|||
Can't be used with keep_checkpoint_max at the same time.
|
||||
integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.
|
||||
Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
|
||||
model_type (str): Model type in `normal`, `fusion` or `quant`. Default: "normal".
|
||||
|
||||
Raises:
|
||||
ValueError: If the input_param is None or 0.
|
||||
|
@ -101,7 +101,8 @@ class CheckpointConfig:
|
|||
save_checkpoint_seconds=0,
|
||||
keep_checkpoint_max=5,
|
||||
keep_checkpoint_per_n_minutes=0,
|
||||
integrated_save=True):
|
||||
integrated_save=True,
|
||||
model_type="normal"):
|
||||
|
||||
if not save_checkpoint_steps and not save_checkpoint_seconds and \
|
||||
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
|
||||
|
@ -115,6 +116,8 @@ class CheckpointConfig:
|
|||
keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
|
||||
if keep_checkpoint_per_n_minutes:
|
||||
keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
|
||||
if model_type:
|
||||
model_type = check_string(model_type, ["normal", "fusion", "quant"])
|
||||
|
||||
self._save_checkpoint_steps = save_checkpoint_steps
|
||||
self._save_checkpoint_seconds = save_checkpoint_seconds
|
||||
|
@ -129,6 +132,7 @@ class CheckpointConfig:
|
|||
if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
|
||||
self._keep_checkpoint_max = 1
|
||||
|
||||
self._model_type = model_type
|
||||
self._integrated_save = check_bool(integrated_save)
|
||||
|
||||
@property
|
||||
|
@ -156,12 +160,18 @@ class CheckpointConfig:
|
|||
"""Get the value of _integrated_save."""
|
||||
return self._integrated_save
|
||||
|
||||
@property
|
||||
def model_type(self):
|
||||
"""Get the value of model_type."""
|
||||
return self._model_type
|
||||
|
||||
def get_checkpoint_policy(self):
|
||||
"""Get the policy of checkpoint."""
|
||||
checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
|
||||
'save_checkpoint_seconds': self._save_checkpoint_seconds,
|
||||
'keep_checkpoint_max': self._keep_checkpoint_max,
|
||||
'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes}
|
||||
'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes,
|
||||
'model_type': self._model_type}
|
||||
|
||||
return checkpoint_policy
|
||||
|
||||
|
@ -226,7 +236,7 @@ class ModelCheckpoint(Callback):
|
|||
graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
|
||||
_save_graph(cb_params.train_network, graph_file_name)
|
||||
self._graph_saved = True
|
||||
self._save_ckpt(cb_params)
|
||||
self._save_ckpt(cb_params, self._config.model_type)
|
||||
|
||||
def end(self, run_context):
|
||||
"""
|
||||
|
@ -237,7 +247,7 @@ class ModelCheckpoint(Callback):
|
|||
"""
|
||||
cb_params = run_context.original_args()
|
||||
_to_save_last_ckpt = True
|
||||
self._save_ckpt(cb_params, _to_save_last_ckpt)
|
||||
self._save_ckpt(cb_params, self._config.model_type, _to_save_last_ckpt)
|
||||
|
||||
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
|
||||
destroy_allgather_cell()
|
||||
|
@ -256,7 +266,7 @@ class ModelCheckpoint(Callback):
|
|||
|
||||
return False
|
||||
|
||||
def _save_ckpt(self, cb_params, force_to_save=False):
|
||||
def _save_ckpt(self, cb_params, model_type, force_to_save=False):
|
||||
"""Save checkpoint files."""
|
||||
if cb_params.cur_step_num == self._last_triggered_step:
|
||||
return
|
||||
|
@ -292,7 +302,7 @@ class ModelCheckpoint(Callback):
|
|||
set_cur_net(cb_params.train_network)
|
||||
cb_params.train_network.exec_checkpoint_graph()
|
||||
|
||||
_exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
|
||||
_exec_save_checkpoint(cb_params.train_network, gen_file, model_type, self._config.integrated_save)
|
||||
|
||||
if os.path.exists(gen_file):
|
||||
shutil.move(gen_file, cur_file)
|
||||
|
|
|
@ -76,7 +76,7 @@ class LossMonitor(Callback):
|
|||
step_loss = np.mean(step_loss.asnumpy())
|
||||
|
||||
self.losses.append(step_loss)
|
||||
cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num)
|
||||
cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num) + 1
|
||||
|
||||
if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)):
|
||||
raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. "
|
||||
|
@ -87,7 +87,7 @@ class LossMonitor(Callback):
|
|||
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
|
||||
print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
|
||||
"loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format(
|
||||
cb_params.cur_epoch_num - 1, cb_params.epoch_num,
|
||||
cb_params.cur_epoch_num, cb_params.epoch_num,
|
||||
cur_step_in_epoch, int(cb_params.batch_num),
|
||||
step_loss, np.mean(self.losses),
|
||||
step_mseconds), flush=True)
|
||||
|
|
|
@ -29,6 +29,7 @@ from mindspore.common.api import _executor
|
|||
from mindspore.common import dtype as mstype
|
||||
from mindspore._checkparam import check_input_data
|
||||
|
||||
|
||||
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"]
|
||||
|
||||
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
|
||||
|
@ -40,6 +41,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
|
|||
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
|
||||
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
|
||||
|
||||
ModelType = ["normal", "fusion", "quant"]
|
||||
|
||||
|
||||
def _special_process_par(par, new_par):
|
||||
"""
|
||||
|
@ -101,20 +104,22 @@ def _update_param(param, new_param):
|
|||
param.set_parameter_data(type(param.data)(new_param.data))
|
||||
|
||||
|
||||
def save_checkpoint(parameter_list, ckpoint_file_name):
|
||||
def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
|
||||
"""
|
||||
Saves checkpoint info to a specified file.
|
||||
|
||||
Args:
|
||||
parameter_list (list): Parameters list, each element is a dict
|
||||
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
|
||||
ckpoint_file_name (str): Checkpoint file name.
|
||||
ckpt_file_name (str): Checkpoint file name.
|
||||
model_type (str): The name of model type. Default: "normal".
|
||||
|
||||
Raises:
|
||||
RuntimeError: Failed to save the Checkpoint file.
|
||||
"""
|
||||
logger.info("Execute save checkpoint process.")
|
||||
checkpoint_list = Checkpoint()
|
||||
checkpoint_list.model_type = model_type
|
||||
|
||||
try:
|
||||
for param in parameter_list:
|
||||
|
@ -133,22 +138,23 @@ def save_checkpoint(parameter_list, ckpoint_file_name):
|
|||
for dim in param['data'].shape:
|
||||
param_tensor.dims.append(dim)
|
||||
|
||||
with open(ckpoint_file_name, "wb") as f:
|
||||
with open(ckpt_file_name, "wb") as f:
|
||||
f.write(checkpoint_list.SerializeToString())
|
||||
os.chmod(ckpoint_file_name, stat.S_IRUSR)
|
||||
os.chmod(ckpt_file_name, stat.S_IRUSR)
|
||||
|
||||
except BaseException as e:
|
||||
logger.error("Failed to save the checkpoint file %s.", ckpoint_file_name)
|
||||
logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
|
||||
raise RuntimeError(e.__str__())
|
||||
logger.info("Save checkpoint process finish.")
|
||||
|
||||
|
||||
def load_checkpoint(ckpoint_file_name, net=None):
|
||||
def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
|
||||
"""
|
||||
Loads checkpoint info from a specified file.
|
||||
|
||||
Args:
|
||||
ckpoint_file_name (str): Checkpoint file name.
|
||||
ckpt_file_name (str): Checkpoint file name.
|
||||
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
|
||||
net (Cell): Cell network. Default: None
|
||||
|
||||
Returns:
|
||||
|
@ -157,28 +163,33 @@ def load_checkpoint(ckpoint_file_name, net=None):
|
|||
Raises:
|
||||
ValueError: Checkpoint file is incorrect.
|
||||
"""
|
||||
if not isinstance(ckpoint_file_name, str):
|
||||
raise ValueError("The ckpoint_file_name must be String.")
|
||||
if not isinstance(ckpt_file_name, str):
|
||||
raise ValueError("The ckpt_file_name must be string.")
|
||||
|
||||
if not os.path.exists(ckpoint_file_name) or ckpoint_file_name[-5:] != ".ckpt":
|
||||
if model_type not in ModelType:
|
||||
raise ValueError(f"The model_type is not in {ModelType}.")
|
||||
|
||||
if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt":
|
||||
raise ValueError("Please input the correct checkpoint file name.")
|
||||
|
||||
if os.path.getsize(ckpoint_file_name) == 0:
|
||||
if os.path.getsize(ckpt_file_name) == 0:
|
||||
raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")
|
||||
|
||||
logger.info("Execute load checkpoint process.")
|
||||
checkpoint_list = Checkpoint()
|
||||
|
||||
try:
|
||||
with open(ckpoint_file_name, "rb") as f:
|
||||
with open(ckpt_file_name, "rb") as f:
|
||||
pb_content = f.read()
|
||||
checkpoint_list.ParseFromString(pb_content)
|
||||
except BaseException as e:
|
||||
logger.error("Failed to read the checkpoint file %s, please check the correct of the file.", ckpoint_file_name)
|
||||
logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
|
||||
raise ValueError(e.__str__())
|
||||
|
||||
parameter_dict = {}
|
||||
|
||||
if model_type != checkpoint_list.model_type:
|
||||
raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
|
||||
checkpoint_list.model_type, model_type))
|
||||
try:
|
||||
for element in checkpoint_list.value:
|
||||
data = element.tensor.tensor_content
|
||||
|
@ -206,7 +217,7 @@ def load_checkpoint(ckpoint_file_name, net=None):
|
|||
logger.info("Load checkpoint process finish.")
|
||||
|
||||
except BaseException as e:
|
||||
logger.error("Failed to load the checkpoint file %s.", ckpoint_file_name)
|
||||
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
|
||||
raise RuntimeError(e.__str__())
|
||||
|
||||
if net:
|
||||
|
@ -303,14 +314,15 @@ def _save_graph(network, file_name):
|
|||
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
||||
|
||||
|
||||
def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True):
|
||||
def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", integrated_save=True):
|
||||
"""
|
||||
Saves checkpoint for 'ms' backend.
|
||||
|
||||
Args:
|
||||
train_network (Network): The train network for training.
|
||||
ckpoint_file_name (str): The name of checkpoint file.
|
||||
integrated_save (bool): Whether to intergrated save in automatic model parallel scene.
|
||||
ckpt_file_name (str): The name of checkpoint file.
|
||||
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
|
||||
integrated_save (bool): Whether to integrated save in automatic model parallel scene.
|
||||
"""
|
||||
|
||||
param_dict = {}
|
||||
|
@ -334,7 +346,7 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True
|
|||
each_param["data"] = param_data
|
||||
param_list.append(each_param)
|
||||
|
||||
save_checkpoint(param_list, ckpoint_file_name)
|
||||
save_checkpoint(param_list, ckpt_file_name, model_type)
|
||||
|
||||
|
||||
def _get_merged_param_data(net, param_name, param_data):
|
||||
|
|
|
@ -20,16 +20,14 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
|
|||
|
||||
import os
|
||||
import argparse
|
||||
from src.dataset import create_dataset
|
||||
from src.config import mnist_cfg as cfg
|
||||
from src.lenet import LeNet5
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train import Model
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
|
||||
from src.dataset import create_dataset
|
||||
from src.config import mnist_cfg as cfg
|
||||
from src.lenet import LeNet5
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
|
||||
|
@ -49,9 +47,6 @@ if __name__ == "__main__":
|
|||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
||||
repeat_size = cfg.epoch_size
|
||||
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
|
||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
print("============== Starting Testing ==============")
|
||||
|
|
|
@ -128,9 +128,9 @@ After all the following we will get the loss value of each step as following:
|
|||
```bash
|
||||
>>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
|
||||
>>> ...
|
||||
>>> Epoch: [ 10/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
|
||||
>>> Epoch: [ 10/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
|
||||
>>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
|
||||
>>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
|
||||
>>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
|
||||
>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
|
||||
```
|
||||
|
||||
Also, you can just run this command instead.
|
||||
|
@ -197,9 +197,9 @@ After all the following we will get the loss value of each step as following:
|
|||
```bash
|
||||
>>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
|
||||
>>> ...
|
||||
>>> Epoch: [ 10/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
|
||||
>>> Epoch: [ 10/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
|
||||
>>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
|
||||
>>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
|
||||
>>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
|
||||
>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
|
||||
```
|
||||
|
||||
### Evaluate quantization aware model
|
||||
|
@ -215,7 +215,7 @@ param_dict = load_checkpoint(args.ckpt_path)
|
|||
load_param_into_net(network, param_dict)
|
||||
|
||||
# convert funsion netwrok to quantization aware network
|
||||
network = quant.convert_quant_network(network
|
||||
network = quant.convert_quant_network(network)
|
||||
```
|
||||
|
||||
Also, you can just run this command insread.
|
||||
|
|
|
@ -23,7 +23,6 @@ import argparse
|
|||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train import Model
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from src.dataset import create_dataset
|
||||
|
@ -47,16 +46,18 @@ if __name__ == "__main__":
|
|||
ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
|
||||
step_size = ds_eval.get_dataset_size()
|
||||
|
||||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
# define loss
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
||||
repeat_size = cfg.epoch_size
|
||||
# define network optimization
|
||||
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
|
||||
|
||||
# call back and monitor
|
||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
# load check point into network
|
||||
param_dict = load_checkpoint(args.ckpt_path, network.type)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
||||
print("============== Starting Testing ==============")
|
||||
|
|
|
@ -23,7 +23,6 @@ import argparse
|
|||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train import Model
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from mindspore.train.quant import quant
|
||||
|
@ -48,20 +47,21 @@ if __name__ == "__main__":
|
|||
ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
|
||||
step_size = ds_eval.get_dataset_size()
|
||||
|
||||
# define funsion network
|
||||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
# convert funsion netwrok to quantization aware network
|
||||
# convert fusion netwrok to quantization aware network
|
||||
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
|
||||
|
||||
# define loss
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
||||
# define network optimization
|
||||
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
|
||||
|
||||
# call back and monitor
|
||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
# load quantization aware network checkpoint
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
param_dict = load_checkpoint(args.ckpt_path, model_type="quant")
|
||||
load_param_into_net(network, param_dict)
|
||||
|
||||
print("============== Starting Testing ==============")
|
||||
|
|
|
@ -34,8 +34,8 @@ class LeNet5(nn.Cell):
|
|||
super(LeNet5, self).__init__()
|
||||
self.num_class = num_class
|
||||
|
||||
self.conv1 = nn.Conv2d(channel, 6, 5)
|
||||
self.conv2 = nn.Conv2d(6, 16, 5)
|
||||
self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid')
|
||||
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
|
||||
self.fc1 = nn.Dense(16 * 5 * 5, 120)
|
||||
self.fc2 = nn.Dense(120, 84)
|
||||
self.fc3 = nn.Dense(84, self.num_class)
|
||||
|
|
|
@ -32,11 +32,12 @@ class LeNet5(nn.Cell):
|
|||
|
||||
def __init__(self, num_class=10, channel=1):
|
||||
super(LeNet5, self).__init__()
|
||||
self.type = "fusion"
|
||||
self.num_class = num_class
|
||||
|
||||
# change `nn.Conv2d` to `nn.Conv2dBnAct`
|
||||
self.conv1 = nn.Conv2dBnAct(channel, 6, 5, activation='relu')
|
||||
self.conv2 = nn.Conv2dBnAct(6, 16, 5, activation='relu')
|
||||
self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
|
||||
self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
|
||||
# change `nn.Dense` to `nn.DenseBnAct`
|
||||
self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
|
||||
self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
|
||||
|
|
|
@ -46,16 +46,24 @@ if __name__ == "__main__":
|
|||
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
|
||||
step_size = ds_train.get_dataset_size()
|
||||
|
||||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
# define network loss
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
||||
# define network optimization
|
||||
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
||||
|
||||
# call back and monitor
|
||||
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
|
||||
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max,
|
||||
model_type=network.type)
|
||||
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
|
||||
|
||||
# define model
|
||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
print("============== Starting Training ==============")
|
||||
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
|
||||
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
|
||||
dataset_sink_mode=args.dataset_sink_mode)
|
||||
print("============== End Training ==============")
|
||||
|
|
|
@ -48,23 +48,30 @@ if __name__ == "__main__":
|
|||
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
|
||||
step_size = ds_train.get_dataset_size()
|
||||
|
||||
# define funsion network
|
||||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
# load quantization aware network checkpoint
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
param_dict = load_checkpoint(args.ckpt_path, network.type)
|
||||
load_param_into_net(network, param_dict)
|
||||
# convert funsion netwrok to quantization aware network
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
|
||||
|
||||
# define network loss
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
||||
# define network optimization
|
||||
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
||||
|
||||
# call back and monitor
|
||||
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
|
||||
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max,
|
||||
model_type="quant")
|
||||
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
|
||||
|
||||
# define model
|
||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
print("============== Starting Training ==============")
|
||||
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
|
||||
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
|
||||
dataset_sink_mode=args.dataset_sink_mode)
|
||||
print("============== End Training ==============")
|
||||
|
|
|
@ -85,7 +85,7 @@ if __name__ == '__main__':
|
|||
|
||||
is_ckpt_exist = os.path.exists(ckpt_file_path)
|
||||
if is_ckpt_exist:
|
||||
param_dict = load_checkpoint(ckpoint_file_name=ckpt_file_path)
|
||||
param_dict = load_checkpoint(ckpt_file_name=ckpt_file_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
export(net, input_data, file_name=model_path_name, file_format='LITE')
|
||||
print("test lenet predict success.")
|
||||
|
|
|
@ -111,19 +111,19 @@ def test_save_checkpoint():
|
|||
os.chmod('./parameters.ckpt', stat.S_IWRITE)
|
||||
os.remove('./parameters.ckpt')
|
||||
|
||||
ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt')
|
||||
save_checkpoint(parameter_list, ckpoint_file_name)
|
||||
ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
|
||||
save_checkpoint(parameter_list, ckpt_file_name)
|
||||
|
||||
|
||||
def test_load_checkpoint_error_filename():
|
||||
ckpoint_file_name = 1
|
||||
ckpt_file_name = 1
|
||||
with pytest.raises(ValueError):
|
||||
load_checkpoint(ckpoint_file_name)
|
||||
load_checkpoint(ckpt_file_name)
|
||||
|
||||
|
||||
def test_load_checkpoint():
|
||||
ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt')
|
||||
par_dict = load_checkpoint(ckpoint_file_name)
|
||||
ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
|
||||
par_dict = load_checkpoint(ckpt_file_name)
|
||||
|
||||
assert len(par_dict) == 3
|
||||
assert par_dict['param_test'].name == 'param_test'
|
||||
|
@ -136,17 +136,17 @@ def test_checkpoint_manager():
|
|||
""" test_checkpoint_manager """
|
||||
ckp_mgr = _CheckpointManager()
|
||||
|
||||
ckpoint_file_name = os.path.join(_cur_dir, './test1.ckpt')
|
||||
with open(ckpoint_file_name, 'w'):
|
||||
os.chmod(ckpoint_file_name, stat.S_IWUSR | stat.S_IRUSR)
|
||||
ckpt_file_name = os.path.join(_cur_dir, './test1.ckpt')
|
||||
with open(ckpt_file_name, 'w'):
|
||||
os.chmod(ckpt_file_name, stat.S_IWUSR | stat.S_IRUSR)
|
||||
|
||||
ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
|
||||
assert ckp_mgr.ckpoint_num == 1
|
||||
|
||||
ckp_mgr.remove_ckpoint_file(ckpoint_file_name)
|
||||
ckp_mgr.remove_ckpoint_file(ckpt_file_name)
|
||||
ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
|
||||
assert ckp_mgr.ckpoint_num == 0
|
||||
assert not os.path.exists(ckpoint_file_name)
|
||||
assert not os.path.exists(ckpt_file_name)
|
||||
|
||||
another_file_name = os.path.join(_cur_dir, './test2.ckpt')
|
||||
another_file_name = os.path.realpath(another_file_name)
|
||||
|
@ -283,7 +283,7 @@ def test_exec_save_checkpoint():
|
|||
|
||||
loss_net = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(loss_net, opt)
|
||||
_exec_save_checkpoint(train_network, ckpoint_file_name="./new_ckpt.ckpt")
|
||||
_exec_save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt")
|
||||
|
||||
load_checkpoint("new_ckpt.ckpt")
|
||||
|
||||
|
|
Loading…
Reference in New Issue