Merge pull request !40364 from changzherui/add_ckpt_st
This commit is contained in:
i-robot 2022-08-15 12:32:24 +00:00 committed by Gitee
commit 1289a3583e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 182 additions and 0 deletions

View File

@ -0,0 +1,182 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import stat
import pytest
import mindspore.context as context
import mindspore.dataset as ds
import mindspore.dataset.transforms as C
import mindspore.dataset.vision as CV
import mindspore.nn as nn
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore import load_checkpoint
from mindspore.common import dtype as mstype
from mindspore.dataset.vision import Inter
from mindspore.train import Model
from mindspore.common.initializer import TruncatedNormal
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
class LeNet5(nn.Cell):
def __init__(self, num_class=10, channel=1):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = conv(channel, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, self.num_class)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
"""create dataset for train or test"""
mnist_ds = ds.MnistDataset(data_path)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
# apply DatasetOps
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
mnist_ds = mnist_ds.repeat(repeat_size)
return mnist_ds
class ErrorCallback(Callback):
def __init__(self, epoch_num):
self.epoch_num = epoch_num
def step_end(self, run_context):
cb_params = run_context.original_args()
epoch_num = cb_params.cur_epoch_num
if epoch_num == self.epoch_num:
raise RuntimeError("Exec runtime error.")
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ckpt_append_info():
"""
Feature: Save append info during save ckpt.
Description: Save append info during save ckpt.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
network = LeNet5(10)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
model = Model(network, net_loss, net_opt)
ds_train = create_dataset(os.path.join('/home/workspace/mindspore_dataset/mnist', "train"), 32, 1)
cb_config = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
append_info=["epoch_num", "step_num"])
ckpoint_cb = ModelCheckpoint(prefix='append_info', directory="./", config=cb_config)
model.train(3, ds_train, callbacks=ckpoint_cb, dataset_sink_mode=True)
file_list = os.listdir(os.getcwd())
ckpt_list = [k for k in file_list if k.startswith("append_info")]
ckpt_1 = [k for k in ckpt_list if k.startswith("append_info-2")]
dict_1 = load_checkpoint(ckpt_1[0])
assert dict_1["epoch_num"].data.asnumpy() == 2
for file_name in ckpt_list:
if os.path.exists(file_name):
os.chmod(file_name, stat.S_IWRITE)
os.remove(file_name)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_ckpt_error():
"""
Feature: Save ckpt when error in train.
Description: Save ckpt when error in train.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
network = LeNet5(10)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
model = Model(network, net_loss, net_opt)
print("============== Starting Training ==============")
ds_train = create_dataset(os.path.join('/home/workspace/mindspore_dataset/mnist', "train"), 32, 1)
cb_config = CheckpointConfig(save_checkpoint_steps=10000, exception_save=True)
ckpoint_cb = ModelCheckpoint(prefix='error_ckpt', directory="./", config=cb_config)
with pytest.raises(RuntimeError):
model.train(3, ds_train, callbacks=[ckpoint_cb, ErrorCallback(2)], dataset_sink_mode=True)
file_list = os.listdir(os.getcwd())
ckpt_list = [k for k in file_list if k.endswith("_breakpoint.ckpt")]
assert os.path.exists(ckpt_list[0])
if os.path.exists(ckpt_list[0]):
os.chmod(ckpt_list[0], stat.S_IWRITE)
os.remove(ckpt_list[0])