forked from mindspore-Ecosystem/mindspore
!45274 transformer dynamic st case
Merge pull request !45274 from zhangdong/zd_1
This commit is contained in: commit 0476edad8c

@@ -0,0 +1,202 @@
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transformer training script."""
|
||||
import time
|
||||
import numpy as np
|
||||
from easydict import EasyDict as edict
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.optim import Adam
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
|
||||
from mindspore.train.callback import Callback, TimeMonitor
|
||||
from mindspore.common import set_seed
|
||||
import mindspore.dataset as de
|
||||
from transformer.transformer_for_train import TransformerNetworkWithLoss, TransformerTrainOneStepWithLossScaleCell
|
||||
|
||||
|
||||
set_seed(1)
|
||||
|
||||
|
||||
def get_ms_timestamp():
    t = time.time()
    return int(round(t * 1000))


TIME_STAMP_INIT = False
TIME_STAMP_FIRST = 0
EPOCH_SIZE = 1
BATCH_SIZE = 32
LR_BETA2 = 0.997

lr_schedule = edict({'learning_rate': 2.0, 'warmup_steps': 8000, 'start_decay_step': 16000, 'min_lr': 0.0})


class LossCallBack(Callback):
    """
    Monitor the loss during training.
    If the loss is NAN or INF, training is terminated.

    Note:
        If per_print_times is 0, the loss is not printed.

    Args:
        per_print_times (int): Print the loss every `per_print_times` steps. Default: 1.
    """

    def __init__(self, per_print_times=1, rank_id=0):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be an int and >= 0.")
        self._per_print_times = per_print_times
        self.rank_id = rank_id
        self.loss_list = []
        global TIME_STAMP_INIT, TIME_STAMP_FIRST
        if not TIME_STAMP_INIT:
            TIME_STAMP_FIRST = get_ms_timestamp()
            TIME_STAMP_INIT = True

    def step_end(self, run_context):
        """Record the loss at the end of each step."""
        global TIME_STAMP_FIRST
        time_stamp_current = get_ms_timestamp()
        cb_params = run_context.original_args()
        print("time: {}, epoch: {}, step: {}, outputs are {}".format(time_stamp_current - TIME_STAMP_FIRST,
                                                                     cb_params.cur_epoch_num,
                                                                     cb_params.cur_step_num,
                                                                     str(cb_params.net_outputs)))
        result = cb_params.net_outputs[0].asnumpy()
        self.loss_list.append(result)


def linear_warmup(warmup_steps, current_step):
    return min([1.0, float(current_step) / float(warmup_steps)])


def rsqrt_decay(warmup_steps, current_step):
    return float(max([current_step, warmup_steps])) ** -0.5


def rsqrt_hidden(hidden_size):
    return float(hidden_size) ** -0.5


def create_dynamic_lr(schedule, training_steps, learning_rate, warmup_steps, hidden_size,
                      start_decay_step=0, min_lr=0.):
    """Generate the dynamic learning rate schedule."""
    if start_decay_step < warmup_steps:
        start_decay_step = warmup_steps
    lr = []
    for current_step in range(1, training_steps + 1):
        cur_lr = 1.0
        for name in schedule.split("*"):
            if name == "constant":
                cur_lr *= float(learning_rate)
            elif name == "rsqrt_hidden":
                cur_lr *= rsqrt_hidden(hidden_size)
            elif name == "linear_warmup":
                cur_lr *= linear_warmup(warmup_steps, current_step)
            elif name == "rsqrt_decay":
                cur_lr *= rsqrt_decay(warmup_steps, current_step - start_decay_step + warmup_steps)
            else:
                raise ValueError("unknown learning rate schedule")
        if warmup_steps < current_step < start_decay_step:
            cur_lr = lr[-1]
        if current_step > warmup_steps:
            cur_lr = max([cur_lr, min_lr])
        lr.append(cur_lr)
    return lr


def fun(data, shape):
    data = data.reshape(shape)
    return data[0], data[1], data[2], data[3], data[4], data[5]


def create_transformer_dynamic_dataset(rank_size=1, rank_id=0, do_shuffle="true"):
    dataset = de.MindDataset(
        "/home/workspace/mindspore_dataset/transformer/test-dynamic-mindrecord",
        columns_list=["batch_data", "batch_shape"],
        shuffle=(do_shuffle == "true"), num_shards=rank_size, shard_id=rank_id)

    dataset = dataset.map(fun, input_columns=["batch_data", "batch_shape"],
                          output_columns=["source_eos_ids", "source_eos_mask",
                                          "target_sos_ids", "target_sos_mask",
                                          "target_eos_ids", "target_eos_mask"])
    dataset = dataset.project(["source_eos_ids", "source_eos_mask",
                               "target_sos_ids", "target_sos_mask",
                               "target_eos_ids", "target_eos_mask"])
    return dataset


def get_train_loss():
    """Train the Transformer and collect the per-step loss."""
    ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU", reserve_class_name_in_scope=False,
                   enable_graph_kernel=False)
    # Set the mempool block size in PYNATIVE_MODE to improve memory utilization;
    # this setting has no effect in GRAPH_MODE.
    if ms.get_context("mode") == ms.PYNATIVE_MODE:
        ms.set_context(mempool_block_size="31GB")

    device_num = 1
    rank_id = 0

    dataset = create_transformer_dynamic_dataset(rank_size=device_num, rank_id=rank_id, do_shuffle=True)
    netwithloss = TransformerNetworkWithLoss(True)

    hidden_size = 1024
    learning_rate = 1.0
    lr = Tensor(create_dynamic_lr(schedule="constant*rsqrt_hidden*linear_warmup*rsqrt_decay",
                                  training_steps=dataset.get_dataset_size() * EPOCH_SIZE,
                                  learning_rate=learning_rate,
                                  warmup_steps=lr_schedule.warmup_steps,
                                  hidden_size=hidden_size,
                                  start_decay_step=lr_schedule.start_decay_step,
                                  min_lr=lr_schedule.min_lr), ms.float32)

    optimizer = Adam(netwithloss.trainable_params(), lr, beta2=LR_BETA2)
    loss_callback = LossCallBack(rank_id=rank_id)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_callback]

    scale_manager = DynamicLossScaleManager(init_loss_scale=1024, scale_factor=2, scale_window=2000)
    update_cell = scale_manager.get_update_cell()
    netwithgrads = TransformerTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                            scale_update_cell=update_cell)
    # None marks the sequence axis as dynamic; set_inputs registers this dynamic input signature.
    data_col = Tensor(shape=[BATCH_SIZE, None], dtype=ms.int64)
    netwithgrads.set_inputs(data_col, data_col, data_col, data_col, data_col, data_col, data_col)

    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    model.train(EPOCH_SIZE, dataset, callbacks=callbacks, dataset_sink_mode=True)
    loss_list = loss_callback.loss_list
    return loss_list


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train():
    """
    Feature: Test the simplified dynamic-shape Transformer network on small data.
    Description: The sequence length of the inputs is dynamic.
    Expectation: The training loss on the fixed data matches the expected loss.
    """
    graph_loss = get_train_loss()
    expect_loss = [11.193909]
    assert np.allclose(graph_loss, expect_loss, 1e-3, 1e-3)
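For reference, the schedule string "constant*rsqrt_hidden*linear_warmup*rsqrt_decay" used above multiplies four per-step factors. The sketch below is a standalone closed-form restatement of that product for this test's configuration (learning_rate=1.0, hidden_size=1024, warmup_steps=8000, start_decay_step=16000, min_lr=0.0); sketch_lr is an illustrative name and is not part of the commit, and the max-with-min_lr clamp is omitted because min_lr is 0.

# Illustrative only: closed-form restatement of the schedule built by create_dynamic_lr above.
def sketch_lr(step, learning_rate=1.0, hidden_size=1024, warmup_steps=8000, start_decay_step=16000):
    rate = float(learning_rate)                    # "constant"
    rate *= float(hidden_size) ** -0.5             # "rsqrt_hidden"
    rate *= min(1.0, step / warmup_steps)          # "linear_warmup"
    # "rsqrt_decay": flat at warmup_steps ** -0.5 until start_decay_step, then decays as 1/sqrt(step offset)
    rate *= float(max(step - start_decay_step + warmup_steps, warmup_steps)) ** -0.5
    return rate


if __name__ == "__main__":
    for step in (1, 4000, 8000, 12000, 20000):
        print(step, sketch_lr(step))

Between warmup_steps and start_decay_step the rsqrt_decay factor stays at warmup_steps ** -0.5, which is why create_dynamic_lr can simply repeat lr[-1] over that range.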
@@ -0,0 +1,211 @@
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transformer for training."""
|
||||
import mindspore
|
||||
import mindspore.ops as ops
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.communication.management import get_group_size
|
||||
|
||||
from tests.st.dynamic_shape.transformer.transformer_model import TransformerModel
|
||||
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 5.0
|
||||
BATCH_SIZE_VALUE = 32
|
||||
|
||||
clip_grad = ops.MultitypeFuncGraph("clip_grad")
|
||||
|
||||
|
||||
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
    """
    Clip gradients.

    Inputs:
        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
        clip_value (float): Specifies how much to clip.
        grad (tuple[Tensor]): Gradients.

    Outputs:
        tuple[Tensor], clipped gradients.
    """
    if clip_type not in (0, 1):
        return grad
    dt = ops.dtype(grad)
    if clip_type == 0:
        new_grad = ops.clip_by_value(grad, ops.cast(ops.tuple_to_array((-clip_value,)), dt),
                                     ops.cast(ops.tuple_to_array((clip_value,)), dt))
    else:
        new_grad = nn.ClipByNorm()(grad, ops.cast(ops.tuple_to_array((clip_value,)), dt))
    return new_grad


class TransformerTrainingLoss(nn.Cell):
    def __init__(self):
        super(TransformerTrainingLoss, self).__init__(auto_prefix=False)
        self.vocab_size = 36560
        self.onehot = ops.OneHot()
        self.on_value = Tensor(float(1 - 0.1), mindspore.float32)
        self.off_value = Tensor(0.1 / float(self.vocab_size - 1), mindspore.float32)
        self.reduce_sum = ops.ReduceSum()
        self.reduce_mean = ops.ReduceMean()
        self.reshape = ops.Reshape()
        self.last_idx = (-1,)
        self.flatten = ops.Flatten()
        self.neg = ops.Neg()
        self.cast = ops.Cast()
        self.batch_size = BATCH_SIZE_VALUE

    def construct(self, prediction_scores, label_ids, label_weights, seq_length):
        """Defines the computation performed."""
        flat_shape = (self.batch_size * seq_length,)
        flat_shape = (-1,)
        label_ids = self.reshape(label_ids, flat_shape)
        label_weights = self.cast(self.reshape(label_weights, flat_shape), mindspore.float32)
        one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value)

        per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx))
        numerator = self.reduce_sum(label_weights * per_example_loss, ())
        denominator = self.reduce_sum(label_weights, ()) + \
                      self.cast(ops.tuple_to_array((1e-5,)), mindspore.float32)
        loss = numerator / denominator
        return loss


class TransformerNetworkWithLoss(nn.Cell):
    """
    Provide transformer training loss through network.

    Args:
        is_training (bool): Specifies whether to use the training mode.
        use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.

    Returns:
        Tensor, the loss of the network.
    """
    def __init__(self, is_training, use_one_hot_embeddings=False):
        super(TransformerNetworkWithLoss, self).__init__(auto_prefix=False)
        self.transformer = TransformerModel(is_training, use_one_hot_embeddings)
        self.loss = TransformerTrainingLoss()
        self.cast = ops.Cast()
        self.shape = ops.TensorShape()

    def construct(self,
                  source_ids,
                  source_mask,
                  target_ids,
                  target_mask,
                  label_ids,
                  label_weights):
        """Transformer network with loss."""
        prediction_scores = self.transformer(source_ids, source_mask, target_ids, target_mask)
        seq_length = self.shape(source_ids)[1]
        total_loss = self.loss(prediction_scores, label_ids, label_weights, seq_length)
        return self.cast(total_loss, mindspore.float32)


grad_scale = ops.MultitypeFuncGraph("grad_scale")
reciprocal = ops.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * ops.cast(reciprocal(scale), ops.dtype(grad))


_grad_overflow = ops.MultitypeFuncGraph("_grad_overflow")
grad_overflow = ops.FloatStatus()


@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
    return grad_overflow(grad)


class TransformerTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
    """
    Encapsulation class of Transformer network training.

    Append an optimizer to the training network; the construct function can
    then be called to create the backward graph.

    Args:
        network (Cell): The training network. The loss function must already be wrapped in.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell that updates the loss scale. Default: None.
    """
    def __init__(self, network, optimizer, scale_update_cell=None):
        super(TransformerTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell)
        self.cast = ops.Cast()
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)

        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mindspore.float32))

    def construct(self,
                  source_eos_ids,
                  source_eos_mask,
                  target_sos_ids,
                  target_sos_mask,
                  target_eos_ids,
                  target_eos_mask,
                  sens=None):
        """Defines the computation performed."""
        source_ids = source_eos_ids
        source_mask = source_eos_mask
        target_ids = target_sos_ids
        target_mask = target_sos_mask
        label_ids = target_eos_ids
        label_weights = target_eos_mask

        weights = self.weights
        loss = self.network(source_ids,
                            source_mask,
                            target_ids,
                            target_mask,
                            label_ids,
                            label_weights)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens
        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
        grads = self.grad(self.network, weights)(source_ids,
                                                 source_mask,
                                                 target_ids,
                                                 target_mask,
                                                 label_ids,
                                                 label_weights,
                                                 self.cast(scaling_sens,
                                                           mindspore.float32))

        # apply grad reducer on grads
        grads = self.grad_reducer(grads)
        grads = self.hyper_map(ops.partial(grad_scale, scaling_sens * self.degree), grads)
        grads = self.hyper_map(ops.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)

        cond = self.get_overflow_status(status, grads)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if not overflow:
            self.optimizer(grads)
        return (loss, cond, scaling_sens)
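The loss computed by TransformerTrainingLoss.construct above is a masked, label-smoothed cross entropy: smoothed one-hot targets are dotted with the prediction scores (assumed here to be flattened log-probabilities), and the weighted per-token losses are averaged over the non-padding positions. A minimal NumPy sketch of that reduction, with illustrative names and shapes:

import numpy as np

def sketch_loss(log_probs, label_ids, label_weights, vocab_size, smoothing=0.1):
    """Masked, label-smoothed cross entropy, mirroring TransformerTrainingLoss.construct."""
    on_value = 1.0 - smoothing
    off_value = smoothing / (vocab_size - 1)
    flat_ids = label_ids.reshape(-1)
    one_hot = np.full((flat_ids.size, vocab_size), off_value, dtype=np.float32)
    one_hot[np.arange(flat_ids.size), flat_ids] = on_value
    per_example = -np.sum(log_probs.reshape(-1, vocab_size) * one_hot, axis=-1)
    weights = label_weights.reshape(-1).astype(np.float32)
    return np.sum(weights * per_example) / (np.sum(weights) + 1e-5)


# Tiny usage example with random log-probabilities and a padding mask.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    vocab, tokens = 8, 6
    logits = rng.normal(size=(tokens, vocab)).astype(np.float32)
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    ids = rng.integers(0, vocab, size=tokens)
    mask = np.array([1, 1, 1, 1, 0, 0], dtype=np.float32)  # last two positions are padding
    print(sketch_loss(log_probs, ids, mask, vocab_size=vocab))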
File diff suppressed because it is too large
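The gradient post-processing in the train-one-step cell above first unscales each gradient by the loss scale (tensor_grad_scale) and then clips its L2 norm to GRADIENT_CLIP_VALUE via nn.ClipByNorm (_clip_grad with clip_type 1). A small NumPy illustration of those two steps, assuming a single device so degree is 1; sketch_unscale_and_clip is an illustrative name, not part of the commit:

import numpy as np

def sketch_unscale_and_clip(grad, loss_scale, clip_value=5.0):
    """Undo the loss scaling, then clip the gradient's L2 norm."""
    grad = grad / loss_scale                  # grad * reciprocal(scale)
    norm = np.sqrt(np.sum(grad ** 2))
    if norm > clip_value:                     # keep the per-tensor norm <= clip_value
        grad = grad * (clip_value / norm)
    return grad


if __name__ == "__main__":
    g = np.array([3.0, 4.0]) * 1024.0         # a gradient produced under loss scale 1024
    print(sketch_unscale_and_clip(g, loss_scale=1024.0))  # norm is exactly 5.0, so it is left unchanged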