forked from mindspore-Ecosystem/mindspore
!6357 add grad acc test case
Merge pull request !6357 from jinyaohui/master
This commit is contained in:
commit
daff211538
|
@ -1816,8 +1816,6 @@ FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
|
|||
parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph);
|
||||
}
|
||||
|
||||
UpdataParam(func_graph, cell);
|
||||
|
||||
// ret = cell_obj(*arg, *kwargs)
|
||||
auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), {param_vargs, param_vkwargs});
|
||||
|
||||
|
|
|
@ -0,0 +1,220 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as CT
|
||||
import mindspore.dataset.vision.c_transforms as CV
|
||||
import mindspore.nn as nn
|
||||
from mindspore import ParameterTuple
|
||||
from mindspore import context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import Normal
|
||||
from mindspore.dataset.vision import Inter
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.train.dataset_helper import DatasetHelper
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
|
||||
_sum_op = C.MultitypeFuncGraph("grad_sum_op")
|
||||
_clear_op = C.MultitypeFuncGraph("clear_op")
|
||||
|
||||
|
||||
@_sum_op.register("Tensor", "Tensor")
|
||||
def _cumulative_gard(grad_sum, grad):
|
||||
"""Apply gard sum to cumulative gradient."""
|
||||
add = P.AssignAdd()
|
||||
return add(grad_sum, grad)
|
||||
|
||||
|
||||
@_clear_op.register("Tensor", "Tensor")
|
||||
def _clear_grad_sum(grad_sum, zero):
|
||||
"""Apply zero to clear grad_sum."""
|
||||
success = True
|
||||
success = F.depend(success, F.assign(grad_sum, zero))
|
||||
return success
|
||||
|
||||
|
||||
class LeNet5(nn.Cell):
|
||||
"""
|
||||
Lenet network
|
||||
|
||||
Args:
|
||||
num_class (int): Num classes. Default: 10.
|
||||
num_channel (int): Num channels. Default: 1.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor
|
||||
Examples:
|
||||
>>> LeNet(num_class=10)
|
||||
"""
|
||||
def __init__(self, num_class=10, num_channel=1):
|
||||
super(LeNet5, self).__init__()
|
||||
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
|
||||
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
|
||||
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
|
||||
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
|
||||
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
|
||||
self.relu = nn.ReLU()
|
||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.flatten = nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.max_pool2d(self.relu(self.conv1(x)))
|
||||
x = self.max_pool2d(self.relu(self.conv2(x)))
|
||||
x = self.flatten(x)
|
||||
x = self.relu(self.fc1(x))
|
||||
x = self.relu(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
|
||||
class TrainForwardBackward(Cell):
|
||||
def __init__(self, network, optimizer, grad_sum, sens=1.0):
|
||||
super(TrainForwardBackward, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad_sum = grad_sum
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
def construct(self, *inputs):
|
||||
weights = self.weights
|
||||
loss = self.network(*inputs)
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(*inputs, sens)
|
||||
return F.depend(loss, self.hyper_map(F.partial(_sum_op), self.grad_sum, grads))
|
||||
|
||||
|
||||
class TrainOptim(Cell):
|
||||
def __init__(self, optimizer, grad_sum):
|
||||
super(TrainOptim, self).__init__(auto_prefix=False)
|
||||
self.optimizer = optimizer
|
||||
self.grad_sum = grad_sum
|
||||
|
||||
def construct(self):
|
||||
return self.optimizer(self.grad_sum)
|
||||
|
||||
|
||||
class TrainClear(Cell):
|
||||
def __init__(self, grad_sum, zeros):
|
||||
super(TrainClear, self).__init__(auto_prefix=False)
|
||||
self.grad_sum = grad_sum
|
||||
self.zeros = zeros
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
def construct(self):
|
||||
seccess = self.hyper_map(F.partial(_clear_op), self.grad_sum, self.zeros)
|
||||
return seccess
|
||||
|
||||
|
||||
class GradientAccumulation:
|
||||
def __init__(self, network, loss_fn, optimizer):
|
||||
self._network = network
|
||||
self._loss_fn = loss_fn
|
||||
self._optimizer = optimizer
|
||||
|
||||
params = self._optimizer.parameters
|
||||
self._grad_sum = params.clone(prefix="grad_sum", init='zeros')
|
||||
self._zeros = params.clone(prefix="zeros", init='zeros')
|
||||
self._train_forward_backward = self._build_train_forward_backward_network()
|
||||
self._train_optim = self._build_train_optim()
|
||||
self._train_clear = self._build_train_clear()
|
||||
|
||||
def _build_train_forward_backward_network(self):
|
||||
"""Build forward and backward network"""
|
||||
network = self._network
|
||||
network = nn.WithLossCell(network, self._loss_fn)
|
||||
loss_scale = 1.0
|
||||
network = TrainForwardBackward(network, self._optimizer, self._grad_sum, loss_scale).set_train()
|
||||
return network
|
||||
|
||||
def _build_train_optim(self):
|
||||
"""Build optimizer network"""
|
||||
network = TrainOptim(self._optimizer, self._grad_sum).set_train()
|
||||
return network
|
||||
|
||||
def _build_train_clear(self):
|
||||
"""Build clear network"""
|
||||
network = TrainClear(self._grad_sum, self._zeros).set_train()
|
||||
return network
|
||||
|
||||
def train_process(self, epoch, train_dataset, mini_steps=None):
|
||||
"""
|
||||
Training process. The data would be passed to network directly.
|
||||
"""
|
||||
dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False, epoch_num=epoch)
|
||||
|
||||
for i in range(epoch):
|
||||
step = 0
|
||||
for k, next_element in enumerate(dataset_helper):
|
||||
loss = self._train_forward_backward(*next_element)
|
||||
if (k + 1) % mini_steps == 0:
|
||||
step += 1
|
||||
print("epoch:", i + 1, "step:", step, "loss is ", loss)
|
||||
self._train_optim()
|
||||
self._train_clear()
|
||||
|
||||
train_dataset.reset()
|
||||
|
||||
save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt",)
|
||||
|
||||
|
||||
def create_dataset(data_path, batch_size=32, repeat_size=1,
|
||||
num_parallel_workers=1):
|
||||
"""
|
||||
create dataset for train or test
|
||||
"""
|
||||
# define dataset
|
||||
mnist_ds = ds.MnistDataset(data_path)
|
||||
|
||||
resize_height, resize_width = 32, 32
|
||||
rescale = 1.0 / 255.0
|
||||
shift = 0.0
|
||||
rescale_nml = 1 / 0.3081
|
||||
shift_nml = -1 * 0.1307 / 0.3081
|
||||
|
||||
# define map operations
|
||||
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
|
||||
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
|
||||
rescale_op = CV.Rescale(rescale, shift)
|
||||
hwc2chw_op = CV.HWC2CHW()
|
||||
type_cast_op = CT.TypeCast(mstype.int32)
|
||||
|
||||
# apply map operations on images
|
||||
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
|
||||
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
|
||||
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
|
||||
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
|
||||
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
|
||||
|
||||
# apply DatasetOps
|
||||
buffer_size = 10000
|
||||
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script
|
||||
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
|
||||
mnist_ds = mnist_ds.repeat(repeat_size)
|
||||
|
||||
return mnist_ds
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_gradient_accumulation():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
ds_train = create_dataset(os.path.join("/home/workspace/mindspore_dataset/mnist", "train"), 32)
|
||||
|
||||
network = LeNet5(10)
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
|
||||
model = GradientAccumulation(network, net_loss, net_opt)
|
||||
|
||||
print("============== Starting Training ==============")
|
||||
model.train_process(2, ds_train, mini_steps=4)
|
Loading…
Reference in New Issue