forked from mindspore-Ecosystem/mindspore
modify comment
This commit is contained in:
parent 9f08cdc4ab
commit 46319f3198
@@ -1727,7 +1727,6 @@ class Tensor(Tensor_):
    Numpy arguments `dtype`, `out` and `where` are not supported.

    Args:
        self (Tensor): A Tensor to be calculated.
        axis (Union[None, int, tuple(int)]): Axis or axes along which the standard
            deviation is computed. Default: `None`.
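For context, a minimal usage sketch of the documented method (assuming a MindSpore build where the numpy-style `Tensor.std` is available, as this docstring implies; the sample values are illustrative):

    import numpy as np
    from mindspore import Tensor

    x = Tensor(np.array([[1., 2.], [3., 4.]], dtype=np.float32))
    print(x.std())        # standard deviation over all elements
    print(x.std(axis=0))  # reduce along axis 0 only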
@@ -80,13 +80,17 @@ class Callback:
    Callback function will execute some operations in the current step or epoch.

    It holds the information of the model, such as `network`, `train_network`, `epoch_num`, `batch_num`,
    `loss_fn`, `optimizer`, `parallel_mode`, `device_number`, `list_callback`, `cur_epoch_num`,
    `cur_step_num`, `dataset_sink_mode`, `net_outputs` and so on.

    Examples:
        >>> from mindspore import Model, nn
        >>> from mindspore.train.callback import Callback
        >>> class Print_info(Callback):
-       >>>     def step_end(self, run_context):
-       >>>         cb_params = run_context.original_args()
-       >>>         print("step_num: ", cb_params.cur_step_num)
+       ...     def step_end(self, run_context):
+       ...         cb_params = run_context.original_args()
+       ...         print("step_num: ", cb_params.cur_step_num)
        >>>
        >>> print_cb = Print_info()
        >>> dataset = create_custom_dataset()
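Building on the corrected example above, a hedged sketch of how such a callback is typically wired into training; the model/dataset setup is assumed and not part of this diff, and the subclass name is illustrative:

    from mindspore.train.callback import Callback

    class StepPrinter(Callback):          # illustrative subclass name
        def step_end(self, run_context):
            cb_params = run_context.original_args()   # exposes cur_step_num, net_outputs, ...
            print("step_num:", cb_params.cur_step_num)

    # Assuming `model` is a mindspore.Model and `dataset` an existing dataset:
    # model.train(1, dataset, callbacks=[StepPrinter()])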
@@ -69,6 +69,11 @@ class CheckpointConfig:
    During the training process, if dataset is transmitted through the data channel,
    it is suggested to set 'save_checkpoint_steps' to an integer multiple of loop_size.
    Otherwise, the time to save the checkpoint may be biased.
+   It is recommended to set only one save strategy and one keep strategy at the same time.
+   If both `save_checkpoint_steps` and `save_checkpoint_seconds` are set,
+   `save_checkpoint_seconds` will be invalid.
+   If both `keep_checkpoint_max` and `keep_checkpoint_per_n_minutes` are set,
+   `keep_checkpoint_per_n_minutes` will be invalid.

    Args:
        save_checkpoint_steps (int): Steps to save checkpoint. Default: 1.
@@ -77,13 +82,13 @@ class CheckpointConfig:
        keep_checkpoint_max (int): Maximum number of checkpoint files that can be saved. Default: 5.
        keep_checkpoint_per_n_minutes (int): Save the checkpoint file every `keep_checkpoint_per_n_minutes` minutes.
            Can't be used with keep_checkpoint_max at the same time. Default: 0.
-       integrated_save (bool): Whether to perform integrated save function in automatic model parallel scene.
+       integrated_save (bool): Whether to merge and save the split Tensor in the automatic parallel scenario.
            Integrated save function is only supported in automatic parallel scene, not supported
            in manual parallel. Default: True.
        async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False.
        saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation
            with the network in training, the initial value of saved_network will be saved. Default: None.
-       append_info (list): The information save to checkpoint file. Support "epoch_num"、"step_num"、and dict.
+       append_info (list): The information to save to the checkpoint file. Supports "epoch_num", "step_num" and dict.
            The key of dict must be str, the value of dict must be one of int, float and bool. Default: None.
        enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
            is not required. Default: None.
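To illustrate the precedence rules documented above (one save strategy, one keep strategy), a minimal configuration sketch; the prefix, directory, and step count are placeholders:

    from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

    # save_checkpoint_steps takes precedence over save_checkpoint_seconds, and
    # keep_checkpoint_max over keep_checkpoint_per_n_minutes, so set only one of each.
    config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
    ckpt_cb = ModelCheckpoint(prefix="lenet", directory="./checkpoint", config=config_ck)
    # ckpt_cb is then passed to Model.train via the callbacks argument.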
@@ -98,25 +103,25 @@ class CheckpointConfig:
        >>> from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
        >>>
        >>> class LeNet5(nn.Cell):
-       >>>     def __init__(self, num_class=10, num_channel=1):
-       >>>         super(LeNet5, self).__init__()
-       >>>         self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
-       >>>         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
-       >>>         self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
-       >>>         self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
-       >>>         self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
-       >>>         self.relu = nn.ReLU()
-       >>>         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
-       >>>         self.flatten = nn.Flatten()
-       >>>
-       >>>     def construct(self, x):
-       >>>         x = self.max_pool2d(self.relu(self.conv1(x)))
-       >>>         x = self.max_pool2d(self.relu(self.conv2(x)))
-       >>>         x = self.flatten(x)
-       >>>         x = self.relu(self.fc1(x))
-       >>>         x = self.relu(self.fc2(x))
-       >>>         x = self.fc3(x)
-       >>>         return x
+       ...     def __init__(self, num_class=10, num_channel=1):
+       ...         super(LeNet5, self).__init__()
+       ...         self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
+       ...         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
+       ...         self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
+       ...         self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
+       ...         self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
+       ...         self.relu = nn.ReLU()
+       ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+       ...         self.flatten = nn.Flatten()
+       ...
+       ...     def construct(self, x):
+       ...         x = self.max_pool2d(self.relu(self.conv1(x)))
+       ...         x = self.max_pool2d(self.relu(self.conv2(x)))
+       ...         x = self.flatten(x)
+       ...         x = self.relu(self.fc1(x))
+       ...         x = self.relu(self.fc2(x))
+       ...         x = self.fc3(x)
+       ...         return x
        >>>
        >>> net = LeNet5()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
@@ -465,7 +470,7 @@ class CheckpointManager:
        self._ckpoint_filelist = []
        files = os.listdir(directory)
        for filename in files:
-           if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix):
+           if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix + "-"):
                mid_name = filename[len(prefix):-5]
                flag = not (True in [char.isalpha() for char in mid_name])
                if flag:
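The effect of the tightened filter can be checked with a small standalone sketch that mirrors the updated condition; the sample file names are made up:

    import os

    def is_managed_ckpt(filename, prefix):
        # Mirrors the updated filter: ".ckpt" extension, name starts with "<prefix>-",
        # and the remainder contains no alphabetic characters.
        if os.path.splitext(filename)[-1] != ".ckpt" or not filename.startswith(prefix + "-"):
            return False
        mid_name = filename[len(prefix):-5]
        return not any(char.isalpha() for char in mid_name)

    print(is_managed_ckpt("lenet-1_2.ckpt", "lenet"))    # True
    print(is_managed_ckpt("lenet1-2_3.ckpt", "lenet"))   # False; the old startswith(prefix) check accepted it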
@@ -30,7 +30,7 @@ class LossMonitor(Callback):
    If per_print_times is 0, do not print loss.

    Args:
-       per_print_times (int): Print the loss each every seconds. Default: 1.
+       per_print_times (int): Print the loss every `per_print_times` steps. Default: 1.

    Raises:
        ValueError: If per_print_times is not an integer or less than zero.
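A short usage sketch consistent with the corrected description; the model and dataset are assumed to exist:

    from mindspore.train.callback import LossMonitor

    loss_cb = LossMonitor(per_print_times=100)   # print the loss every 100 steps; 0 disables printing
    # model.train(10, dataset, callbacks=[loss_cb])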
@@ -596,6 +596,8 @@ class Model:
    Note:
        If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features
        of data will be transferred one by one. The limitation of data transmission per time is 256M.
+       When dataset_sink_mode is True, step_end method of the Callback class will be executed when
+       the epoch_end method is called.
        If sink_size > 0, each epoch the dataset can be traversed unlimited times until you get sink_size
        elements of the dataset. Next epoch continues to traverse from the end position of the previous traversal.
        The interface builds the computational graphs and then executes the computational graphs.
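A hedged sketch of how these notes surface in practice, reusing the LeNet5 definition from the CheckpointConfig example above; the dataset helper is a placeholder:

    from mindspore import Model, nn

    net = LeNet5()
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    dataset = create_custom_dataset()   # placeholder, as in the Callback example
    # With dataset_sink_mode=True, data is sent to the device and each epoch consumes
    # sink_size elements; Callback.step_end then fires when epoch_end is called.
    model.train(10, dataset, dataset_sink_mode=True, sink_size=100)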
@@ -771,6 +773,8 @@ class Model:
    Note:
        If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features
        of data will be transferred one by one. The limitation of data transmission per time is 256M.
+       When dataset_sink_mode is True, step_end method of the Callback class will be executed when
+       the epoch_end method is called.

    Args:
        valid_dataset (Dataset): Dataset to evaluate the model.
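For evaluation the same note applies; a minimal sketch continuing the training example above, where eval_dataset is a placeholder:

    model = Model(net, loss_fn=loss, metrics={"accuracy"})
    # dataset_sink_mode=False keeps step_end firing per step instead of at epoch_end.
    acc = model.eval(eval_dataset, dataset_sink_mode=False)
    print(acc)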
@@ -193,7 +193,7 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
                    async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM"):
    """
-   Saves checkpoint info to a specified file.
+   Save checkpoint info to a specified file.

    Args:
        save_obj (Union[Cell, list]): The cell object or data list (each element is a dictionary, like
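The two accepted forms of `save_obj` can be sketched as follows; `net` is assumed to be an existing Cell (e.g. the LeNet5 above) and the parameter name and shape are illustrative:

    import numpy as np
    from mindspore import Tensor
    from mindspore.train.serialization import save_checkpoint

    save_checkpoint(net, "./net.ckpt")                       # save_obj as a Cell
    params = [{"name": "conv1.weight",
               "data": Tensor(np.zeros((6, 1, 5, 5), np.float32))}]
    save_checkpoint(params, "./params.ckpt")                 # save_obj as a list of dicts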
@@ -370,7 +370,7 @@ def load(file_name, **kwargs):

def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM"):
    """
-   Loads checkpoint info from a specified file.
+   Load checkpoint info from a specified file.

    Args:
        ckpt_file_name (str): Checkpoint file name.
@@ -499,7 +499,7 @@ def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):

def load_param_into_net(net, parameter_dict, strict_load=False):
    """
-   Loads parameters into network.
+   Load parameters into network.

    Args:
        net (Cell): Cell network.
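These two load functions are normally used together; a hedged sketch, assuming ./net.ckpt was produced by save_checkpoint and `net` is the matching Cell:

    from mindspore.train.serialization import load_checkpoint, load_param_into_net

    param_dict = load_checkpoint("./net.ckpt")        # returns a name -> Parameter dictionary
    load_param_into_net(net, param_dict, strict_load=False)
    # Or load directly into the network in one step:
    load_checkpoint("./net.ckpt", net=net)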
@@ -824,6 +824,7 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
    if os.path.exists(data_path):
        shutil.rmtree(data_path)
    os.makedirs(data_path, exist_ok=True)
    os.chmod(data_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
    index = 0
    graphproto = graph_proto()
    data_size = 0
@@ -980,7 +981,7 @@ def _quant_export(network, *inputs, file_format, **kwargs):

def parse_print(print_file_name):
    """
-   Loads Print data from a specified file.
+   Load Print data from a specified file.

    Args:
        print_file_name (str): The file name of saved print data.
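A brief usage sketch, assuming a data file previously written by the Print operator; the path is a placeholder:

    from mindspore.train.serialization import parse_print

    tensors = parse_print("./print_output.data")   # recovers the saved Print outputs as Tensors
    for t in tensors:
        print(t)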
@@ -1373,7 +1374,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=

def async_ckpt_thread_status():
    """
-   Get async save checkpoint thread status.
+   Get the status of asynchronous save checkpoint thread.

    Returns:
        True, Asynchronous save checkpoint thread is running.
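A typical (assumed) usage pattern when checkpoints are written with async_save=True is to wait for the background thread before exiting:

    import time
    from mindspore.train.serialization import async_ckpt_thread_status

    # Block until the asynchronous checkpoint-saving thread has finished.
    while async_ckpt_thread_status():
        time.sleep(1)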
@@ -166,7 +166,7 @@ def test_checkpoint_manager():
    """ test_checkpoint_manager """
    ckp_mgr = _CheckpointManager()

-   ckpt_file_name = os.path.join(_cur_dir, './test1.ckpt')
+   ckpt_file_name = os.path.join(_cur_dir, './test-1_1.ckpt')
    with open(ckpt_file_name, 'w'):
        os.chmod(ckpt_file_name, stat.S_IWUSR | stat.S_IRUSR)
@@ -178,7 +178,7 @@ def test_checkpoint_manager():
    assert ckp_mgr.ckpoint_num == 0
    assert not os.path.exists(ckpt_file_name)

-   another_file_name = os.path.join(_cur_dir, './test2.ckpt')
+   another_file_name = os.path.join(_cur_dir, './test-2_1.ckpt')
    another_file_name = os.path.realpath(another_file_name)
    with open(another_file_name, 'w'):
        os.chmod(another_file_name, stat.S_IWUSR | stat.S_IRUSR)
@@ -191,9 +191,9 @@ def test_checkpoint_manager():
    assert not os.path.exists(another_file_name)

    # test keep_one_ckpoint_per_minutes
-   file1 = os.path.realpath(os.path.join(_cur_dir, './time_file1.ckpt'))
-   file2 = os.path.realpath(os.path.join(_cur_dir, './time_file2.ckpt'))
-   file3 = os.path.realpath(os.path.join(_cur_dir, './time_file3.ckpt'))
+   file1 = os.path.realpath(os.path.join(_cur_dir, './time_file-1_1.ckpt'))
+   file2 = os.path.realpath(os.path.join(_cur_dir, './time_file-2_1.ckpt'))
+   file3 = os.path.realpath(os.path.join(_cur_dir, './time_file-3_1.ckpt'))
    with open(file1, 'w'):
        os.chmod(file1, stat.S_IWUSR | stat.S_IRUSR)
    with open(file2, 'w'):
@@ -206,9 +206,9 @@ def test_checkpoint_manager():
    ckp_mgr.keep_one_ckpoint_per_minutes(1, time1)
    ckp_mgr.update_ckpoint_filelist(_cur_dir, "time_file")
    assert ckp_mgr.ckpoint_num == 1
-   if os.path.exists(_cur_dir + '/time_file1.ckpt'):
-       os.chmod(_cur_dir + '/time_file1.ckpt', stat.S_IWRITE)
-       os.remove(_cur_dir + '/time_file1.ckpt')
+   if os.path.exists(_cur_dir + '/time_file-1_1.ckpt'):
+       os.chmod(_cur_dir + '/time_file-1_1.ckpt', stat.S_IWRITE)
+       os.remove(_cur_dir + '/time_file-1_1.ckpt')


def test_load_param_into_net_error_net():