mindspore/tests/st/networks/test_gradient_accumulation.py

221 lines
7.8 KiB
Python

import os
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as CT
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import ParameterTuple
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.initializer import Normal
from mindspore.dataset.vision import Inter
from mindspore.nn import Cell
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train.dataset_helper import DatasetHelper
from mindspore.train.serialization import save_checkpoint
# Multitype function graphs used with HyperMap further down in this file:
# `_sum_op` receives the overload that adds a gradient into its accumulator,
# `_clear_op` the overload that overwrites an accumulator with a zero tensor.
_sum_op = C.MultitypeFuncGraph("grad_sum_op")
_clear_op = C.MultitypeFuncGraph("clear_op")
@_sum_op.register("Tensor", "Tensor")
def _cumulative_gard(grad_sum, grad):
    """Add ``grad`` onto the cumulative gradient ``grad_sum`` via AssignAdd."""
    return P.AssignAdd()(grad_sum, grad)
@_clear_op.register("Tensor", "Tensor")
def _clear_grad_sum(grad_sum, zero):
    """Assign ``zero`` to ``grad_sum`` and return a flag tied to that assign."""
    done = True
    reset = F.assign(grad_sum, zero)
    # depend() makes the returned flag depend on the assign operation.
    return F.depend(done, reset)
class LeNet5(nn.Cell):
    """
    LeNet-5 style network.

    Two valid-padded 5x5 convolutions, each followed by ReLU and 2x2 max
    pooling, feed a stack of three dense layers.

    Args:
        num_class (int): Output size of the last dense layer. Default: 10.
        num_channel (int): Input channels of the first convolution. Default: 1.

    Returns:
        Tensor, output of the last dense layer.

    Examples:
        >>> LeNet5(num_class=10)
    """

    def __init__(self, num_class=10, num_channel=1):
        super().__init__()
        # Size of the flattened feature map handed to the first dense layer.
        flat_features = 16 * 5 * 5
        self.conv1 = nn.Conv2d(in_channels=num_channel, out_channels=6, kernel_size=5, pad_mode='valid')
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, pad_mode='valid')
        self.fc1 = nn.Dense(flat_features, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        # First convolution stage: conv -> ReLU -> max pool.
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        # Second convolution stage.
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        # Dense head; the last layer has no activation.
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)
class TrainForwardBackward(Cell):
    """
    Run one forward/backward pass and add the gradients to the accumulators.

    Args:
        network (Cell): Network whose output is treated as the loss.
        optimizer: Stored on the cell; not used by ``construct``.
        grad_sum: Gradient accumulators updated through ``_sum_op``.
        sens (float): Fill value of the sensitivity tensor. Default: 1.0.
    """

    def __init__(self, network, optimizer, grad_sum, sens=1.0):
        super().__init__(auto_prefix=False)
        network.set_grad()
        network.add_flags(defer_inline=True)
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad_sum = grad_sum
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.hyper_map = C.HyperMap()

    def construct(self, *inputs):
        loss = self.network(*inputs)
        # Sensitivity tensor with the loss's dtype and shape, filled with self.sens.
        loss_dtype = P.DType()(loss)
        loss_shape = P.Shape()(loss)
        sens = P.Fill()(loss_dtype, loss_shape, self.sens)
        grads = self.grad(self.network, self.weights)(*inputs, sens)
        accumulated = self.hyper_map(F.partial(_sum_op), self.grad_sum, grads)
        # Return the loss while keeping the accumulation attached to it.
        return F.depend(loss, accumulated)
class TrainOptim(Cell):
    """Cell that hands the accumulated gradients to the optimizer."""

    def __init__(self, optimizer, grad_sum):
        super().__init__(auto_prefix=False)
        self.optimizer = optimizer
        self.grad_sum = grad_sum

    def construct(self):
        updated = self.optimizer(self.grad_sum)
        return updated
class TrainClear(Cell):
    """Cell that overwrites every gradient accumulator with its zero tensor."""

    def __init__(self, grad_sum, zeros):
        super().__init__(auto_prefix=False)
        self.grad_sum = grad_sum
        self.zeros = zeros
        self.hyper_map = C.HyperMap()

    def construct(self):
        # Apply the registered clear overload to each (accumulator, zero) pair.
        return self.hyper_map(F.partial(_clear_op), self.grad_sum, self.zeros)
class GradientAccumulation:
    """
    Training driver that accumulates gradients over several batches.

    Three cells are built up front: one runs forward/backward and adds the
    gradients to ``grad_sum``, one applies the optimizer to ``grad_sum`` and
    one resets ``grad_sum`` from a set of zero tensors.

    Args:
        network (Cell): Network to train.
        loss_fn (Cell): Loss combined with ``network`` through ``nn.WithLossCell``.
        optimizer: Optimizer; its ``parameters`` are cloned to create the
            accumulators and the zero tensors.
    """

    def __init__(self, network, loss_fn, optimizer):
        self._network = network
        self._loss_fn = loss_fn
        self._optimizer = optimizer

        params = self._optimizer.parameters
        # Two zero-initialised clones of the optimizer parameters: the running
        # gradient sums and the constants used to reset them.
        self._grad_sum = params.clone(prefix="grad_sum", init='zeros')
        self._zeros = params.clone(prefix="zeros", init='zeros')
        self._train_forward_backward = self._build_train_forward_backward_network()
        self._train_optim = self._build_train_optim()
        self._train_clear = self._build_train_clear()

    def _build_train_forward_backward_network(self):
        """Build forward and backward network"""
        network = self._network
        network = nn.WithLossCell(network, self._loss_fn)
        loss_scale = 1.0
        network = TrainForwardBackward(network, self._optimizer, self._grad_sum, loss_scale).set_train()
        return network

    def _build_train_optim(self):
        """Build optimizer network"""
        network = TrainOptim(self._optimizer, self._grad_sum).set_train()
        return network

    def _build_train_clear(self):
        """Build clear network"""
        network = TrainClear(self._grad_sum, self._zeros).set_train()
        return network

    def train_process(self, epoch, train_dataset, mini_steps=None):
        """
        Training process. The data would be passed to network directly.

        Args:
            epoch (int): Number of passes over ``train_dataset``.
            train_dataset: Dataset wrapped in a non-sink ``DatasetHelper``; it
                is reset after every epoch.
            mini_steps (int, optional): Number of forward/backward passes whose
                gradients are accumulated before each optimizer update.
                ``None`` means 1, i.e. update after every batch.

        Raises:
            ValueError: If ``mini_steps`` is zero or negative.
        """
        # The signature advertises None as the default, but the modulo below
        # cannot take None; fall back to updating on every batch.
        if mini_steps is None:
            mini_steps = 1
        if mini_steps <= 0:
            raise ValueError(f"mini_steps must be a positive integer, but got {mini_steps!r}")

        dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False, epoch_num=epoch)

        for i in range(epoch):
            step = 0
            for k, next_element in enumerate(dataset_helper):
                loss = self._train_forward_backward(*next_element)
                if (k + 1) % mini_steps == 0:
                    step += 1
                    print("epoch:", i + 1, "step:", step, "loss is ", loss)
                    self._train_optim()
                    self._train_clear()
            # NOTE(review): k restarts each epoch but grad_sum is only cleared
            # in the branch above, so gradients of a trailing partial window
            # stay accumulated and join the next epoch's first update.
            train_dataset.reset()

        save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt")
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    Build the MNIST pipeline used for train or test.

    Labels are cast to int32; images are resized to 32x32, rescaled twice
    (first by 1/255, then with the normalisation constants) and converted
    from HWC to CHW. The result is shuffled, batched with the remainder
    dropped, and repeated ``repeat_size`` times.
    """
    # Normalisation constants.
    norm_mean = 0.1307
    norm_std = 0.3081
    target_size = (32, 32)

    # Map operations.
    label_cast = CT.TypeCast(mstype.int32)
    resize = CV.Resize(target_size, interpolation=Inter.LINEAR)  # Bilinear mode
    to_unit_range = CV.Rescale(1.0 / 255.0, 0.0)
    normalize = CV.Rescale(1 / norm_std, -1 * norm_mean / norm_std)
    to_chw = CV.HWC2CHW()

    dataset = ds.MnistDataset(data_path)
    dataset = dataset.map(operations=label_cast, input_columns="label",
                          num_parallel_workers=num_parallel_workers)
    # Image operations are applied one map at a time, in this order.
    for image_op in (resize, to_unit_range, normalize, to_chw):
        dataset = dataset.map(operations=image_op, input_columns="image",
                              num_parallel_workers=num_parallel_workers)

    # Dataset-level operations; buffer size 10000 as in LeNet train script.
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.repeat(repeat_size)
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_gradient_accumulation():
    """Train LeNet5 on MNIST for 2 epochs, updating weights every 4 batches."""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

    train_dir = os.path.join("/home/workspace/mindspore_dataset/mnist", "train")
    train_set = create_dataset(train_dir, 32)

    lenet = LeNet5(10)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    momentum = nn.Momentum(lenet.trainable_params(), 0.01, 0.9)
    trainer = GradientAccumulation(lenet, loss_fn, momentum)

    print("============== Starting Training ==============")
    trainer.train_process(2, train_set, mini_steps=4)