modify comment

This commit is contained in:
changzherui 2021-08-22 15:06:55 +08:00
parent 9f08cdc4ab
commit 46319f3198
7 changed files with 53 additions and 40 deletions

View File

@ -1727,7 +1727,6 @@ class Tensor(Tensor_):
Numpy arguments `dtype`, `out` and `where` are not supported.
Args:
self (Tensor): A Tensor to be calculated.
axis (Union[None, int, tuple(int)]): Axis or axes along which the standard
deviation is computed. Default: `None`.

View File

@ -80,13 +80,17 @@ class Callback:
Callback function will execute some operations in the current step or epoch.
It holds the information of the model. Such as `network`, `train_network`, `epoch_num`, `batch_num`,
`loss_fn`, `optimizer`, `parallel_mode`, `device_number`, `list_callback`, `cur_epoch_num`,
`cur_step_num`, `dataset_sink_mode`, `net_outputs` and so on.
Examples:
>>> from mindspore import Model, nn
>>> from mindspore.train.callback import Callback
>>> class Print_info(Callback):
>>> def step_end(self, run_context):
>>> cb_params = run_context.original_args()
>>> print("step_num: ", cb_params.cur_step_num)
... def step_end(self, run_context):
... cb_params = run_context.original_args()
... print("step_num: ", cb_params.cur_step_num)
>>>
>>> print_cb = Print_info()
>>> dataset = create_custom_dataset()

View File

@ -69,6 +69,11 @@ class CheckpointConfig:
During the training process, if dataset is transmitted through the data channel,
It is suggested to set 'save_checkpoint_steps' to an integer multiple of loop_size.
Otherwise, the time to save the checkpoint may be biased.
It is recommended to set only one save strategy and one keep strategy at the same time.
If both `save_checkpoint_steps` and `save_checkpoint_seconds` are set,
`save_checkpoint_seconds` will be invalid.
If both `keep_checkpoint_max` and `keep_checkpoint_per_n_minutes` are set,
`keep_checkpoint_per_n_minutes` will be invalid.
Args:
save_checkpoint_steps (int): Steps to save checkpoint. Default: 1.
@ -77,13 +82,13 @@ class CheckpointConfig:
keep_checkpoint_max (int): Maximum number of checkpoint files can be saved. Default: 5.
keep_checkpoint_per_n_minutes (int): Save the checkpoint file every `keep_checkpoint_per_n_minutes` minutes.
Can't be used with keep_checkpoint_max at the same time. Default: 0.
integrated_save (bool): Whether to perform integrated save function in automatic model parallel scene.
integrated_save (bool): Whether to merge and save the split Tensor in the automatic parallel scenario.
Integrated save function is only supported in automatic parallel scene, not supported
in manual parallel. Default: True.
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False.
saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation
with the network in training, the initial value of saved_network will be saved. Default: None.
append_info (list): The information save to checkpoint file. Support "epoch_num""step_num"and dict.
append_info (list): The information save to checkpoint file. Support "epoch_num", "step_num" and dict.
The key of dict must be str, the value of dict must be one of int float and bool. Default: None.
enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
is not required. Default: None.
@ -98,25 +103,25 @@ class CheckpointConfig:
>>> from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
>>>
>>> class LeNet5(nn.Cell):
>>> def __init__(self, num_class=10, num_channel=1):
>>> super(LeNet5, self).__init__()
>>> self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
>>> self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
>>> self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
>>> self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
>>> self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
>>> self.relu = nn.ReLU()
>>> self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
>>> self.flatten = nn.Flatten()
>>>
>>> def construct(self, x):
>>> x = self.max_pool2d(self.relu(self.conv1(x)))
>>> x = self.max_pool2d(self.relu(self.conv2(x)))
>>> x = self.flatten(x)
>>> x = self.relu(self.fc1(x))
>>> x = self.relu(self.fc2(x))
>>> x = self.fc3(x)
>>> return x
... def __init__(self, num_class=10, num_channel=1):
... super(LeNet5, self).__init__()
... self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
... self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
... self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
... self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
... self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
... self.relu = nn.ReLU()
... self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
... self.flatten = nn.Flatten()
...
... def construct(self, x):
... x = self.max_pool2d(self.relu(self.conv1(x)))
... x = self.max_pool2d(self.relu(self.conv2(x)))
... x = self.flatten(x)
... x = self.relu(self.fc1(x))
... x = self.relu(self.fc2(x))
... x = self.fc3(x)
... return x
>>>
>>> net = LeNet5()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
@ -465,7 +470,7 @@ class CheckpointManager:
self._ckpoint_filelist = []
files = os.listdir(directory)
for filename in files:
if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix):
if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix + "-"):
mid_name = filename[len(prefix):-5]
flag = not (True in [char.isalpha() for char in mid_name])
if flag:

View File

@ -30,7 +30,7 @@ class LossMonitor(Callback):
If per_print_times is 0, do not print loss.
Args:
per_print_times (int): Print the loss each every seconds. Default: 1.
per_print_times (int): Print the loss every seconds. Default: 1.
Raises:
ValueError: If per_print_times is not an integer or less than zero.

View File

@ -596,6 +596,8 @@ class Model:
Note:
If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features
of data will be transferred one by one. The limitation of data transmission per time is 256M.
When dataset_sink_mode is True, step_end method of the Callback class will be executed when
the epoch_end method is called.
If sink_size > 0, each epoch the dataset can be traversed unlimited times until you get sink_size
elements of the dataset. Next epoch continues to traverse from the end position of the previous traversal.
The interface builds the computational graphs and then executes the computational graphs.
@ -771,6 +773,8 @@ class Model:
Note:
If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features
of data will be transferred one by one. The limitation of data transmission per time is 256M.
When dataset_sink_mode is True, step_end method of the Callback class will be executed when
the epoch_end method is called.
Args:
valid_dataset (Dataset): Dataset to evaluate the model.

View File

@ -193,7 +193,7 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM"):
"""
Saves checkpoint info to a specified file.
Save checkpoint info to a specified file.
Args:
save_obj (Union[Cell, list]): The cell object or data list(each element is a dictionary, like
@ -370,7 +370,7 @@ def load(file_name, **kwargs):
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM"):
"""
Loads checkpoint info from a specified file.
Load checkpoint info from a specified file.
Args:
ckpt_file_name (str): Checkpoint file name.
@ -499,7 +499,7 @@ def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
def load_param_into_net(net, parameter_dict, strict_load=False):
"""
Loads parameters into network.
Load parameters into network.
Args:
net (Cell): Cell network.
@ -824,6 +824,7 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
if os.path.exists(data_path):
shutil.rmtree(data_path)
os.makedirs(data_path, exist_ok=True)
os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
index = 0
graphproto = graph_proto()
data_size = 0
@ -980,7 +981,7 @@ def _quant_export(network, *inputs, file_format, **kwargs):
def parse_print(print_file_name):
"""
Loads Print data from a specified file.
Load Print data from a specified file.
Args:
print_file_name (str): The file name of saved print data.
@ -1373,7 +1374,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
def async_ckpt_thread_status():
"""
Get async save checkpoint thread status.
Get the status of asynchronous save checkpoint thread.
Returns:
True, Asynchronous save checkpoint thread is running.

View File

@ -166,7 +166,7 @@ def test_checkpoint_manager():
""" test_checkpoint_manager """
ckp_mgr = _CheckpointManager()
ckpt_file_name = os.path.join(_cur_dir, './test1.ckpt')
ckpt_file_name = os.path.join(_cur_dir, './test-1_1.ckpt')
with open(ckpt_file_name, 'w'):
os.chmod(ckpt_file_name, stat.S_IWUSR | stat.S_IRUSR)
@ -178,7 +178,7 @@ def test_checkpoint_manager():
assert ckp_mgr.ckpoint_num == 0
assert not os.path.exists(ckpt_file_name)
another_file_name = os.path.join(_cur_dir, './test2.ckpt')
another_file_name = os.path.join(_cur_dir, './test-2_1.ckpt')
another_file_name = os.path.realpath(another_file_name)
with open(another_file_name, 'w'):
os.chmod(another_file_name, stat.S_IWUSR | stat.S_IRUSR)
@ -191,9 +191,9 @@ def test_checkpoint_manager():
assert not os.path.exists(another_file_name)
# test keep_one_ckpoint_per_minutes
file1 = os.path.realpath(os.path.join(_cur_dir, './time_file1.ckpt'))
file2 = os.path.realpath(os.path.join(_cur_dir, './time_file2.ckpt'))
file3 = os.path.realpath(os.path.join(_cur_dir, './time_file3.ckpt'))
file1 = os.path.realpath(os.path.join(_cur_dir, './time_file-1_1.ckpt'))
file2 = os.path.realpath(os.path.join(_cur_dir, './time_file-2_1.ckpt'))
file3 = os.path.realpath(os.path.join(_cur_dir, './time_file-3_1.ckpt'))
with open(file1, 'w'):
os.chmod(file1, stat.S_IWUSR | stat.S_IRUSR)
with open(file2, 'w'):
@ -206,9 +206,9 @@ def test_checkpoint_manager():
ckp_mgr.keep_one_ckpoint_per_minutes(1, time1)
ckp_mgr.update_ckpoint_filelist(_cur_dir, "time_file")
assert ckp_mgr.ckpoint_num == 1
if os.path.exists(_cur_dir + '/time_file1.ckpt'):
os.chmod(_cur_dir + '/time_file1.ckpt', stat.S_IWRITE)
os.remove(_cur_dir + '/time_file1.ckpt')
if os.path.exists(_cur_dir + '/time_file-1_1.ckpt'):
os.chmod(_cur_dir + '/time_file-1_1.ckpt', stat.S_IWRITE)
os.remove(_cur_dir + '/time_file-1_1.ckpt')
def test_load_param_into_net_error_net():