remove build-in function in network script
This commit is contained in:
parent
a050c265de
commit
9a31b8d54a
|
@ -539,18 +539,17 @@ class EmbeddingThor(Cell):
|
|||
embedding_size (int): The size of each embedding vector.
|
||||
use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: False.
|
||||
embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializes the embedding_table.
|
||||
Refer to class `initializer` for the values of string when a string
|
||||
is specified. Default: 'normal'.
|
||||
Refer to class `initializer` for the values of string when a string is specified. Default: 'normal'.
|
||||
dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
|
||||
padding_idx (int, None): When the padding_idx encounters index, the output embedding vector of this index
|
||||
will be initialized to zero. Default: None. The feature is inactivated.
|
||||
Inputs:
|
||||
- **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The elements of
|
||||
- **input** (Tensor) - Tensor of input shape :math:`(\text{batch_size}, \text{input_length})`. The elements of
|
||||
the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
|
||||
be zero.
|
||||
|
||||
Outputs:
|
||||
Tensor of shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`.
|
||||
Tensor of output shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`.
|
||||
|
||||
Examples:
|
||||
>>> net = nn.Embedding(20000, 768, True)
|
||||
|
|
|
@ -612,8 +612,8 @@ class ThorAscend(Optimizer):
|
|||
|
||||
def _process_matrix_init_and_weight_idx_map(self, net):
|
||||
"""for Ascend, process matrix init shape, and get weight idx map"""
|
||||
layer_type_map = get_net_layertype_mask(net)
|
||||
layer_counter = 0
|
||||
layer_type_map = get_net_layertype_mask(net)
|
||||
for idx in range(len(self.params)):
|
||||
layer_type = layer_type_map[layer_counter]
|
||||
weight = self.params[idx]
|
||||
|
|
|
@ -136,7 +136,7 @@ def bert_predict():
|
|||
'''
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
|
||||
dataset = get_enwiki_512_dataset(bert_net_cfg.batch_size, 1)
|
||||
dataset = get_enwiki_512_dataset(cfg.batch_size, 1)
|
||||
net_for_pretraining = BertPretrainEva(bert_net_cfg)
|
||||
net_for_pretraining.set_train(False)
|
||||
param_dict = load_checkpoint(cfg.finetune_ckpt)
|
||||
|
|
|
@ -19,13 +19,6 @@ python run_pretrain.py
|
|||
|
||||
import argparse
|
||||
import os
|
||||
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
|
||||
from src.bert_net_config import bert_net_cfg
|
||||
from src.config import cfg
|
||||
from src.dataset import create_bert_dataset
|
||||
from src.lr_generator import get_bert_lr, get_bert_damping
|
||||
from src.model_thor import Model
|
||||
from src.utils import LossCallBack
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.communication.management as D
|
||||
from mindspore import context
|
||||
|
@ -35,6 +28,14 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMoni
|
|||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.train_thor import ConvertModelUtils
|
||||
from mindspore.nn.optim import thor
|
||||
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
|
||||
from src.dataset import create_bert_dataset
|
||||
from src.config import cfg, bert_net_cfg
|
||||
from src.utils import LossCallBack
|
||||
|
||||
|
||||
_current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
|
@ -69,17 +70,19 @@ def _set_bert_all_reduce_split():
|
|||
def _get_optimizer(args_opt, network):
|
||||
"""get thor optimizer."""
|
||||
if cfg.optimizer == "Thor":
|
||||
if args_opt.distribute == "true":
|
||||
from src.thor_for_bert_arg import THOR
|
||||
else:
|
||||
from src.thor_for_bert import THOR
|
||||
lr = get_bert_lr()
|
||||
damping = get_bert_damping()
|
||||
optimizer = THOR(filter(lambda x: x.requires_grad, network.get_parameters()), lr, cfg.Thor.momentum,
|
||||
filter(lambda x: 'matrix_A' in x.name, network.get_parameters()),
|
||||
filter(lambda x: 'matrix_G' in x.name, network.get_parameters()),
|
||||
cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
|
||||
bert_net_cfg.batch_size, damping)
|
||||
from src.utils import get_bert_thor_lr, get_bert_thor_damping
|
||||
lr = get_bert_thor_lr(cfg.Thor.lr_max, cfg.Thor.lr_min, cfg.Thor.lr_power, cfg.Thor.lr_total_steps)
|
||||
damping = get_bert_thor_damping(cfg.Thor.damping_max, cfg.Thor.damping_min, cfg.Thor.damping_power,
|
||||
cfg.Thor.damping_total_steps)
|
||||
split_indices = None
|
||||
if bert_net_cfg.num_hidden_layers == 12 and not bert_net_cfg.use_relative_positions:
|
||||
split_indices = [28, 55, 77]
|
||||
elif bert_net_cfg.num_hidden_layers == 24 and not bert_net_cfg.use_relative_positions:
|
||||
split_indices = [38, 93, 149]
|
||||
optimizer = thor(network, lr, damping, cfg.Thor.momentum,
|
||||
cfg.Thor.weight_decay, cfg.Thor.loss_scale, cfg.batch_size,
|
||||
decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
split_indices=split_indices, enable_clip_grad=True, frequency=cfg.Thor.frequency)
|
||||
else:
|
||||
raise ValueError("Don't support optimizer {}, only support [Thor]".format(cfg.optimizer))
|
||||
return optimizer
|
||||
|
@ -113,7 +116,6 @@ def run_pretrain():
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
|
||||
device_id=args_opt.device_id, save_graphs=False)
|
||||
context.set_context(reserve_class_name_in_scope=False)
|
||||
context.set_context(max_call_depth=3000)
|
||||
ckpt_save_dir = args_opt.save_checkpoint_path
|
||||
if args_opt.distribute == "true":
|
||||
D.init()
|
||||
|
@ -144,7 +146,7 @@ def run_pretrain():
|
|||
logger.info("train steps: {}".format(args_opt.train_steps))
|
||||
|
||||
optimizer = _get_optimizer(args_opt, net_with_loss)
|
||||
callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
|
||||
callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(ds.get_dataset_size())]
|
||||
if args_opt.enable_save_ckpt == "true" and rank == 0:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
|
||||
keep_checkpoint_max=args_opt.save_checkpoint_num)
|
||||
|
@ -162,11 +164,13 @@ def run_pretrain():
|
|||
net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
|
||||
scale_update_cell=update_cell)
|
||||
else:
|
||||
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
|
||||
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer, sens=cfg.Thor.loss_scale,
|
||||
enable_clip_grad=False)
|
||||
|
||||
model = Model(net_with_grads, frequency=cfg.Thor.frequency)
|
||||
model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"),
|
||||
sink_size=args_opt.data_sink_steps)
|
||||
model = Model(net_with_grads)
|
||||
model = ConvertModelUtils().convert_to_thor_model(model, network=net_with_grads, optimizer=optimizer)
|
||||
model.train(new_repeat_count, ds, callbacks=callback,
|
||||
dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -15,17 +15,22 @@
|
|||
"""Bert Init."""
|
||||
from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \
|
||||
BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \
|
||||
BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
|
||||
BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
|
||||
BertTrainAccumulationAllReduceEachWithLossScaleCell, \
|
||||
BertTrainAccumulationAllReducePostWithLossScaleCell, \
|
||||
BertTrainOneStepWithLossScaleCellForAdam
|
||||
from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
|
||||
BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
|
||||
EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
|
||||
SaturateCast, CreateAttentionMaskFromInputMask
|
||||
|
||||
__all__ = [
|
||||
"BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
|
||||
"GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell", "BertTrainOneStepWithLossScaleCell",
|
||||
"GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell",
|
||||
"BertTrainOneStepWithLossScaleCell", "BertTrainAccumulationAllReduceEachWithLossScaleCell",
|
||||
"BertTrainAccumulationAllReducePostWithLossScaleCell",
|
||||
"BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
|
||||
"BertSelfAttention", "BertTransformer", "EmbeddingLookup",
|
||||
"EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator",
|
||||
"RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask"
|
||||
"RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask",
|
||||
"BertTrainOneStepWithLossScaleCellForAdam"
|
||||
]
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -16,26 +16,19 @@
|
|||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import initializer, TruncatedNormal
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore import context
|
||||
from .bert_model import BertModel
|
||||
from .config import cfg
|
||||
from .lr_generator import get_bert_damping
|
||||
from .thor_layer import Dense_Thor
|
||||
|
||||
damping = get_bert_damping()
|
||||
loss_scale = cfg.Thor.loss_scale
|
||||
frequency = cfg.Thor.frequency
|
||||
batch_size = cfg.Thor.batch_size
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 1.0
|
||||
|
||||
|
@ -84,16 +77,10 @@ class GetMaskedLMOutput(nn.Cell):
|
|||
self.gather = P.Gather()
|
||||
|
||||
weight_init = TruncatedNormal(config.initializer_range)
|
||||
self.dense = Dense_Thor(in_channels=self.width,
|
||||
out_channels=config.hidden_size,
|
||||
weight_init=weight_init,
|
||||
has_bias=True,
|
||||
bias_init='zeros',
|
||||
damping=damping,
|
||||
loss_scale=loss_scale,
|
||||
frequency=frequency,
|
||||
activation=config.hidden_act,
|
||||
batch_size=batch_size).to_float(config.compute_type)
|
||||
self.dense = nn.Dense(self.width,
|
||||
config.hidden_size,
|
||||
weight_init=weight_init,
|
||||
activation=config.hidden_act).to_float(config.compute_type)
|
||||
self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)
|
||||
self.output_bias = Parameter(
|
||||
initializer(
|
||||
|
@ -102,9 +89,8 @@ class GetMaskedLMOutput(nn.Cell):
|
|||
self.matmul = P.MatMul(transpose_b=True)
|
||||
self.log_softmax = nn.LogSoftmax(axis=-1)
|
||||
self.shape_flat_offsets = (-1, 1)
|
||||
self.rng = Tensor(np.array(range(0, config.batch_size)).astype(np.int32))
|
||||
self.last_idx = (-1,)
|
||||
self.shape_flat_sequence_tensor = (config.batch_size * config.seq_length, self.width)
|
||||
self.shape_flat_sequence_tensor = (-1, self.width)
|
||||
self.seq_length_tensor = Tensor(np.array((config.seq_length,)).astype(np.int32))
|
||||
self.cast = P.Cast()
|
||||
self.compute_type = config.compute_type
|
||||
|
@ -114,9 +100,9 @@ class GetMaskedLMOutput(nn.Cell):
|
|||
input_tensor,
|
||||
output_weights,
|
||||
positions):
|
||||
"""construct of GetMaskedLMOutput"""
|
||||
flat_offsets = self.reshape(
|
||||
self.rng * self.seq_length_tensor, self.shape_flat_offsets)
|
||||
"""Get output log_probs"""
|
||||
rng = F.tuple_to_array(F.make_range(P.Shape()(input_tensor)[0]))
|
||||
flat_offsets = self.reshape(rng * self.seq_length_tensor, self.shape_flat_offsets)
|
||||
flat_position = self.reshape(positions + flat_offsets, self.last_idx)
|
||||
flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor)
|
||||
input_tensor = self.gather(flat_sequence_tensor, flat_position, 0)
|
||||
|
@ -264,7 +250,7 @@ class BertNetworkWithLoss(nn.Cell):
|
|||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights):
|
||||
"""construct of BertNetworkWithLoss"""
|
||||
"""Get pre-training loss"""
|
||||
prediction_scores, seq_relationship_score = \
|
||||
self.bert(input_ids, input_mask, token_type_id, masked_lm_positions)
|
||||
total_loss = self.loss(prediction_scores, seq_relationship_score,
|
||||
|
@ -283,11 +269,14 @@ class BertTrainOneStepCell(nn.TrainOneStepCell):
|
|||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
sens (Number): The adjust parameter. Default: 1.0.
|
||||
enable_clip_grad (boolean): If True, clip gradients in BertTrainOneStepCell. Default: False.
|
||||
"""
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
|
||||
def __init__(self, network, optimizer, sens=1.0, enable_clip_grad=False):
|
||||
super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
|
||||
self.cast = P.Cast()
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.enable_clip_grad = enable_clip_grad
|
||||
|
||||
def set_sens(self, value):
|
||||
self.sens = value
|
||||
|
@ -319,7 +308,8 @@ class BertTrainOneStepCell(nn.TrainOneStepCell):
|
|||
masked_lm_weights,
|
||||
self.cast(F.tuple_to_array((self.sens,)),
|
||||
mstype.float32))
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
if self.enable_clip_grad:
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
grads = self.grad_reducer(grads)
|
||||
succ = self.optimizer(grads)
|
||||
return F.depend(loss, succ)
|
||||
|
@ -334,7 +324,16 @@ def tensor_grad_scale(scale, grad):
|
|||
return grad * reciprocal(scale)
|
||||
|
||||
|
||||
class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
||||
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
|
||||
grad_overflow = P.FloatStatus()
|
||||
|
||||
|
||||
@_grad_overflow.register("Tensor")
|
||||
def _tensor_grad_overflow(grad):
|
||||
return grad_overflow(grad)
|
||||
|
||||
|
||||
class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
|
||||
"""
|
||||
Encapsulation class of bert network training.
|
||||
|
||||
|
@ -348,15 +347,208 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
"""
|
||||
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||
super(BertTrainOneStepWithLossScaleCell, self).__init__(network, optimizer, scale_update_cell)
|
||||
self.cast = P.Cast()
|
||||
self.degree = 1
|
||||
if self.reducer_flag:
|
||||
self.degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
||||
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
sens=None):
|
||||
"""Defines the computation performed."""
|
||||
weights = self.weights
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
|
||||
cond = self.get_overflow_status(status, grads)
|
||||
overflow = cond
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
ret = (loss, cond, scaling_sens)
|
||||
return F.depend(ret, succ)
|
||||
|
||||
|
||||
class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell):
|
||||
"""
|
||||
Encapsulation class of bert network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
Different from BertTrainOneStepWithLossScaleCell, the optimizer takes the overflow
|
||||
condition as input.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
|
||||
"""
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
super(BertTrainOneStepWithLossScaleCellForAdam, self).__init__(network, optimizer, scale_update_cell)
|
||||
self.cast = P.Cast()
|
||||
self.degree = 1
|
||||
if self.reducer_flag:
|
||||
self.degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
sens=None):
|
||||
"""Defines the computation performed."""
|
||||
weights = self.weights
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
|
||||
status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
cond = self.get_overflow_status(status, grads)
|
||||
overflow = cond
|
||||
if self.loss_scaling_manager is not None:
|
||||
overflow = self.loss_scaling_manager(scaling_sens, cond)
|
||||
succ = self.optimizer(grads, overflow)
|
||||
ret = (loss, cond, scaling_sens)
|
||||
return F.depend(ret, succ)
|
||||
|
||||
cast = P.Cast()
|
||||
add_grads = C.MultitypeFuncGraph("add_grads")
|
||||
|
||||
|
||||
@add_grads.register("Tensor", "Tensor")
|
||||
def _add_grads(accu_grad, grad):
|
||||
return accu_grad + cast(grad, mstype.float32)
|
||||
|
||||
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
|
||||
|
||||
@update_accu_grads.register("Tensor", "Tensor")
|
||||
def _update_accu_grads(accu_grad, grad):
|
||||
succ = True
|
||||
return F.depend(succ, F.assign(accu_grad, cast(grad, mstype.float32)))
|
||||
|
||||
accumulate_accu_grads = C.MultitypeFuncGraph("accumulate_accu_grads")
|
||||
|
||||
@accumulate_accu_grads.register("Tensor", "Tensor")
|
||||
def _accumulate_accu_grads(accu_grad, grad):
|
||||
succ = True
|
||||
return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))
|
||||
|
||||
|
||||
zeroslike = P.ZerosLike()
|
||||
reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")
|
||||
|
||||
|
||||
@reset_accu_grads.register("Tensor")
|
||||
def _reset_accu_grads(accu_grad):
|
||||
succ = True
|
||||
return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))
|
||||
|
||||
|
||||
class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of bert network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
To mimic higher batch size, gradients are accumulated N times before weight update.
|
||||
|
||||
For distribution mode, allreduce will only be implemented in the weight updated step,
|
||||
i.e. the sub-step after gradients accumulated N times.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
|
||||
accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
|
||||
batch_size * accumulation_steps. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
|
||||
super(BertTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True,
|
||||
sens_param=True)
|
||||
self.accumulation_steps = accumulation_steps
|
||||
self.enable_global_norm = enable_global_norm
|
||||
self.one = Tensor(np.array([1]).astype(np.int32))
|
||||
self.zero = Tensor(np.array([0]).astype(np.int32))
|
||||
self.local_step = Parameter(initializer(0, [1], mstype.int32))
|
||||
self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
|
||||
self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
|
||||
self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
|
||||
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.allreduce = P.AllReduce()
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
|
@ -366,6 +558,9 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
self.degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.overflow_reducer = F.identity
|
||||
if self.is_distributed:
|
||||
self.overflow_reducer = P.AllReduce()
|
||||
self.cast = P.Cast()
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
|
@ -373,6 +568,10 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.logical_or = P.LogicalOr()
|
||||
self.not_equal = P.NotEqual()
|
||||
self.select = P.Select()
|
||||
self.reshape = P.Reshape()
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
|
@ -406,6 +605,13 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
init = F.depend(init, loss)
|
||||
clear_status = self.clear_status(init)
|
||||
scaling_sens = F.depend(scaling_sens, clear_status)
|
||||
# update accumulation parameters
|
||||
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
|
||||
self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
|
||||
self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
|
||||
mean_loss = self.accu_loss / self.local_step
|
||||
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
|
||||
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
|
@ -415,26 +621,190 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
|||
masked_lm_weights,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
init = F.depend(init, grads)
|
||||
|
||||
accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads)
|
||||
mean_loss = F.depend(mean_loss, accu_succ)
|
||||
|
||||
init = F.depend(init, mean_loss)
|
||||
get_status = self.get_status(init)
|
||||
init = F.depend(init, get_status)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
if self.is_distributed:
|
||||
# sum overflow flag over devices
|
||||
flag_reduce = self.allreduce(flag_sum)
|
||||
cond = self.less_equal(self.base, flag_reduce)
|
||||
else:
|
||||
cond = self.less_equal(self.base, flag_sum)
|
||||
overflow = cond
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
if overflow:
|
||||
overflow = self.less_equal(self.base, flag_sum)
|
||||
overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
|
||||
accu_overflow = self.select(overflow, self.one, self.zero)
|
||||
self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
|
||||
|
||||
if is_accu_step:
|
||||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
ret = (loss, cond, scaling_sens)
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(self.accu_grads)
|
||||
scaling = scaling_sens * self.degree * self.accumulation_steps
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
|
||||
if self.enable_global_norm:
|
||||
grads = C.clip_by_global_norm(grads, 1.0, None)
|
||||
else:
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
accu_overflow = F.depend(accu_overflow, grads)
|
||||
accu_overflow = self.overflow_reducer(accu_overflow)
|
||||
overflow = self.less_equal(self.base, accu_overflow)
|
||||
accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
|
||||
overflow = F.depend(overflow, accu_succ)
|
||||
overflow = self.reshape(overflow, (()))
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, overflow)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
|
||||
ret = (mean_loss, overflow, scaling_sens)
|
||||
return F.depend(ret, succ)
|
||||
|
||||
|
||||
class BertTrainAccumulationAllReduceEachWithLossScaleCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of bert network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
To mimic higher batch size, gradients are accumulated N times before weight update.
|
||||
|
||||
For distribution mode, allreduce will be implemented after each sub-step and the trailing time
|
||||
will be overided by backend optimization pass.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
|
||||
accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
|
||||
batch_size * accumulation_steps. Default: 1.
|
||||
"""
|
||||
def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False):
|
||||
super(BertTrainAccumulationAllReduceEachWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.accumulation_steps = accumulation_steps
|
||||
self.enable_global_norm = enable_global_norm
|
||||
self.one = Tensor(np.array([1]).astype(np.int32))
|
||||
self.zero = Tensor(np.array([0]).astype(np.int32))
|
||||
self.local_step = Parameter(initializer(0, [1], mstype.int32))
|
||||
self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
|
||||
self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
|
||||
self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
|
||||
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = F.identity
|
||||
self.degree = 1
|
||||
if self.reducer_flag:
|
||||
self.degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.overflow_reducer = F.identity
|
||||
if self.is_distributed:
|
||||
self.overflow_reducer = P.AllReduce()
|
||||
self.cast = P.Cast()
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.logical_or = P.LogicalOr()
|
||||
self.not_equal = P.NotEqual()
|
||||
self.select = P.Select()
|
||||
self.reshape = P.Reshape()
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
|
||||
|
||||
@C.add_flags(has_effect=True)
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
sens=None):
|
||||
"""Defines the computation performed."""
|
||||
weights = self.weights
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
|
||||
# update accumulation parameters
|
||||
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
|
||||
self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
|
||||
self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
|
||||
mean_loss = self.accu_loss / self.local_step
|
||||
is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
|
||||
|
||||
# alloc status and clear should be right before gradoperation
|
||||
init = self.alloc_status()
|
||||
self.clear_before_grad(init)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
|
||||
|
||||
accu_grads = self.hyper_map(add_grads, self.accu_grads, grads)
|
||||
scaling = scaling_sens * self.degree * self.accumulation_steps
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling), accu_grads)
|
||||
grads = self.grad_reducer(grads)
|
||||
|
||||
self.get_status(init)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
flag_reduce = self.overflow_reducer(flag_sum)
|
||||
overflow = self.less_equal(self.base, flag_reduce)
|
||||
overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
|
||||
accu_overflow = self.select(overflow, self.one, self.zero)
|
||||
self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
|
||||
overflow = self.reshape(overflow, (()))
|
||||
|
||||
if is_accu_step:
|
||||
succ = False
|
||||
accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, accu_grads)
|
||||
succ = F.depend(succ, accu_succ)
|
||||
else:
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, overflow)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
if self.enable_global_norm:
|
||||
grads = C.clip_by_global_norm(grads, 1.0, None)
|
||||
else:
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
|
||||
succ = self.optimizer(grads)
|
||||
|
||||
accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
|
||||
succ = F.depend(succ, accu_succ)
|
||||
|
||||
ret = (mean_loss, overflow, scaling_sens)
|
||||
return F.depend(ret, succ)
|
||||
|
|
|
@ -14,27 +14,17 @@
|
|||
# ============================================================================
|
||||
"""Bert model."""
|
||||
|
||||
import copy
|
||||
import math
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.functional as F
|
||||
from mindspore.common.initializer import TruncatedNormal, initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from .config import cfg
|
||||
from .lr_generator import get_bert_damping
|
||||
from .thor_layer import Dense_Thor, Embedding_Thor
|
||||
|
||||
damping = get_bert_damping()
|
||||
loss_scale = cfg.Thor.loss_scale
|
||||
frequency = cfg.Thor.frequency
|
||||
batch_size = cfg.Thor.batch_size
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
|
||||
class BertConfig:
|
||||
|
@ -42,7 +32,6 @@ class BertConfig:
|
|||
Configuration for `BertModel`.
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size of input dataset.
|
||||
seq_length (int): Length of input sequence. Default: 128.
|
||||
vocab_size (int): The shape of each embedding vector. Default: 32000.
|
||||
hidden_size (int): Size of the bert encoder layers. Default: 768.
|
||||
|
@ -62,16 +51,10 @@ class BertConfig:
|
|||
type_vocab_size (int): Size of token type vocab. Default: 16.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from
|
||||
dataset. Default: True.
|
||||
token_type_ids_from_dataset (bool): Specifies whether to use the token type ids that loaded
|
||||
from dataset. Default: True.
|
||||
dtype (:class:`mindspore.dtype`): Data type of the input. Default: mstype.float32.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
batch_size,
|
||||
seq_length=128,
|
||||
vocab_size=32000,
|
||||
hidden_size=768,
|
||||
|
@ -85,12 +68,8 @@ class BertConfig:
|
|||
type_vocab_size=16,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float32,
|
||||
enable_fused_layernorm=False):
|
||||
self.batch_size = batch_size
|
||||
compute_type=mstype.float32):
|
||||
self.seq_length = seq_length
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
|
@ -103,12 +82,9 @@ class BertConfig:
|
|||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.input_mask_from_dataset = input_mask_from_dataset
|
||||
self.token_type_ids_from_dataset = token_type_ids_from_dataset
|
||||
self.use_relative_positions = use_relative_positions
|
||||
self.dtype = dtype
|
||||
self.compute_type = compute_type
|
||||
self.enable_fused_layernorm = enable_fused_layernorm
|
||||
|
||||
|
||||
class EmbeddingLookup(nn.Cell):
|
||||
|
@ -123,7 +99,6 @@ class EmbeddingLookup(nn.Cell):
|
|||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
embedding_size,
|
||||
|
@ -147,7 +122,7 @@ class EmbeddingLookup(nn.Cell):
|
|||
self.shape = tuple(embedding_shape)
|
||||
|
||||
def construct(self, input_ids):
|
||||
"""construct of EmbeddingLookup"""
|
||||
"""Get output and embeddings lookup table"""
|
||||
extended_ids = self.expand(input_ids, -1)
|
||||
flat_ids = self.reshape(extended_ids, self.shape_flat)
|
||||
if self.use_one_hot_embeddings:
|
||||
|
@ -176,7 +151,6 @@ class EmbeddingPostprocessor(nn.Cell):
|
|||
model. Default: 512.
|
||||
dropout_prob (float): The dropout probability. Default: 0.1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embedding_size,
|
||||
embedding_shape,
|
||||
|
@ -192,16 +166,10 @@ class EmbeddingPostprocessor(nn.Cell):
|
|||
self.token_type_vocab_size = token_type_vocab_size
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.token_type_embedding = Embedding_Thor(
|
||||
self.token_type_embedding = nn.Embedding(
|
||||
vocab_size=token_type_vocab_size,
|
||||
embedding_size=embedding_size,
|
||||
embedding_shape=embedding_shape,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=initializer_range,
|
||||
batch_size=batch_size,
|
||||
damping=damping,
|
||||
loss_scale=loss_scale,
|
||||
frequency=frequency)
|
||||
use_one_hot=use_one_hot_embeddings)
|
||||
self.shape_flat = (-1,)
|
||||
self.one_hot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
|
@ -213,30 +181,23 @@ class EmbeddingPostprocessor(nn.Cell):
|
|||
self.gather = P.Gather()
|
||||
self.use_relative_positions = use_relative_positions
|
||||
self.slice = P.StridedSlice()
|
||||
_, seq, width = self.shape
|
||||
position_embedding_shape = [1, seq, width]
|
||||
self.full_position_embedding = Embedding_Thor(
|
||||
_, seq, _ = self.shape
|
||||
self.full_position_embedding = nn.Embedding(
|
||||
vocab_size=max_position_embeddings,
|
||||
embedding_size=embedding_size,
|
||||
embedding_shape=position_embedding_shape,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=initializer_range,
|
||||
batch_size=batch_size,
|
||||
damping=damping,
|
||||
loss_scale=loss_scale,
|
||||
frequency=frequency)
|
||||
self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
|
||||
use_one_hot=False)
|
||||
self.layernorm = nn.LayerNorm((embedding_size,))
|
||||
self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
|
||||
self.add = P.Add()
|
||||
|
||||
def construct(self, token_type_ids, word_embeddings):
|
||||
"""construct of EmbeddingPostprocessor"""
|
||||
"""Postprocessors apply positional and token type embeddings to word embeddings."""
|
||||
output = word_embeddings
|
||||
if self.use_token_type:
|
||||
token_type_embeddings, _ = self.token_type_embedding(token_type_ids)
|
||||
token_type_embeddings = self.token_type_embedding(token_type_ids)
|
||||
output = self.add(output, token_type_embeddings)
|
||||
if not self.use_relative_positions:
|
||||
position_embeddings, _ = self.full_position_embedding(self.position_ids)
|
||||
position_embeddings = self.full_position_embedding(self.position_ids)
|
||||
output = self.add(output, position_embeddings)
|
||||
output = self.layernorm(output)
|
||||
output = self.dropout(output)
|
||||
|
@ -254,25 +215,15 @@ class BertOutput(nn.Cell):
|
|||
dropout_prob (float): The dropout probability. Default: 0.1.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
initializer_range=0.02,
|
||||
dropout_prob=0.1,
|
||||
compute_type=mstype.float32,
|
||||
enable_fused_layernorm=False):
|
||||
compute_type=mstype.float32):
|
||||
super(BertOutput, self).__init__()
|
||||
self.dense = Dense_Thor(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
weight_init=TruncatedNormal(initializer_range),
|
||||
has_bias=True,
|
||||
bias_init='zeros',
|
||||
damping=damping,
|
||||
loss_scale=loss_scale,
|
||||
frequency=frequency,
|
||||
activation=None,
|
||||
batch_size=batch_size).to_float(compute_type)
|
||||
self.dense = nn.Dense(in_channels, out_channels,
|
||||
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
||||
self.dropout = nn.Dropout(1 - dropout_prob)
|
||||
self.dropout_prob = dropout_prob
|
||||
self.add = P.Add()
|
||||
|
@ -280,7 +231,6 @@ class BertOutput(nn.Cell):
|
|||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, hidden_status, input_tensor):
|
||||
"""construct of BertOutput"""
|
||||
output = self.dense(hidden_status)
|
||||
output = self.dropout(output)
|
||||
output = self.add(input_tensor, output)
|
||||
|
@ -296,7 +246,6 @@ class RelaPosMatrixGenerator(nn.Cell):
|
|||
length (int): Length of one dim for the matrix to be generated.
|
||||
max_relative_position (int): Max value of relative position.
|
||||
"""
|
||||
|
||||
def __init__(self, length, max_relative_position):
|
||||
super(RelaPosMatrixGenerator, self).__init__()
|
||||
self._length = length
|
||||
|
@ -311,7 +260,7 @@ class RelaPosMatrixGenerator(nn.Cell):
|
|||
self.cast = P.Cast()
|
||||
|
||||
def construct(self):
|
||||
"""construct of RelaPosMatrixGenerator"""
|
||||
"""Generates matrix of relative positions between inputs."""
|
||||
range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32)
|
||||
range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1))
|
||||
tile_row_out = self.tile(range_vec_row_out, (self._length,))
|
||||
|
@ -341,7 +290,6 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
|
|||
initializer_range (float): Initialization value of TruncatedNormal.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
length,
|
||||
depth,
|
||||
|
@ -366,10 +314,9 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
|
|||
self.matmul = P.BatchMatMul()
|
||||
|
||||
def construct(self):
|
||||
"""construct of RelaPosEmbeddingsGenerator"""
|
||||
"""Generate embedding for each relative position of dimension depth."""
|
||||
relative_positions_matrix_out = self.relative_positions_matrix()
|
||||
|
||||
# Generate embedding for each relative position of dimension depth.
|
||||
if self.use_one_hot_embeddings:
|
||||
flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
|
||||
one_hot_relative_positions_matrix = self.one_hot(
|
||||
|
@ -392,7 +339,6 @@ class SaturateCast(nn.Cell):
|
|||
src_type (:class:`mindspore.dtype`): The type of the elements of the input tensor. Default: mstype.float32.
|
||||
dst_type (:class:`mindspore.dtype`): The type of the elements of the output tensor. Default: mstype.float32.
|
||||
"""
|
||||
|
||||
def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
|
||||
super(SaturateCast, self).__init__()
|
||||
np_type = mstype.dtype_to_nptype(dst_type)
|
||||
|
@ -406,7 +352,6 @@ class SaturateCast(nn.Cell):
|
|||
self.dst_type = dst_type
|
||||
|
||||
def construct(self, x):
|
||||
"""construct of SaturateCast"""
|
||||
out = self.max_op(x, self.tensor_min_type)
|
||||
out = self.min_op(out, self.tensor_max_type)
|
||||
return self.cast(out, self.dst_type)
|
||||
|
@ -417,7 +362,6 @@ class BertAttention(nn.Cell):
|
|||
Apply multi-headed attention from "from_tensor" to "to_tensor".
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size of input datasets.
|
||||
from_tensor_width (int): Size of last dim of from_tensor.
|
||||
to_tensor_width (int): Size of last dim of to_tensor.
|
||||
from_seq_length (int): Length of from_tensor sequence.
|
||||
|
@ -437,9 +381,7 @@ class BertAttention(nn.Cell):
|
|||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertAttention. Default: mstype.float32.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
batch_size,
|
||||
from_tensor_width,
|
||||
to_tensor_width,
|
||||
from_seq_length,
|
||||
|
@ -458,7 +400,6 @@ class BertAttention(nn.Cell):
|
|||
compute_type=mstype.float32):
|
||||
|
||||
super(BertAttention, self).__init__()
|
||||
self.batch_size = batch_size
|
||||
self.from_seq_length = from_seq_length
|
||||
self.to_seq_length = to_seq_length
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
@ -472,39 +413,21 @@ class BertAttention(nn.Cell):
|
|||
self.shape_to_2d = (-1, to_tensor_width)
|
||||
weight = TruncatedNormal(initializer_range)
|
||||
units = num_attention_heads * size_per_head
|
||||
self.query_layer = Dense_Thor(in_channels=from_tensor_width,
|
||||
out_channels=units,
|
||||
weight_init=weight,
|
||||
has_bias=True,
|
||||
bias_init='zeros',
|
||||
damping=damping,
|
||||
loss_scale=loss_scale,
|
||||
frequency=frequency,
|
||||
activation=query_act,
|
||||
batch_size=batch_size).to_float(compute_type)
|
||||
self.key_layer = Dense_Thor(in_channels=to_tensor_width,
|
||||
out_channels=units,
|
||||
weight_init=weight,
|
||||
has_bias=True,
|
||||
bias_init='zeros',
|
||||
damping=damping,
|
||||
loss_scale=loss_scale,
|
||||
frequency=frequency,
|
||||
activation=key_act,
|
||||
batch_size=batch_size).to_float(compute_type)
|
||||
self.value_layer = Dense_Thor(in_channels=to_tensor_width,
|
||||
out_channels=units,
|
||||
weight_init=weight,
|
||||
has_bias=True,
|
||||
bias_init='zeros',
|
||||
damping=damping,
|
||||
loss_scale=loss_scale,
|
||||
frequency=frequency,
|
||||
activation=value_act,
|
||||
batch_size=batch_size).to_float(compute_type)
|
||||
self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
|
||||
self.shape_to = (
|
||||
batch_size, to_seq_length, num_attention_heads, size_per_head)
|
||||
self.query_layer = nn.Dense(from_tensor_width,
|
||||
units,
|
||||
activation=query_act,
|
||||
weight_init=weight).to_float(compute_type)
|
||||
self.key_layer = nn.Dense(to_tensor_width,
|
||||
units,
|
||||
activation=key_act,
|
||||
weight_init=weight).to_float(compute_type)
|
||||
self.value_layer = nn.Dense(to_tensor_width,
|
||||
units,
|
||||
activation=value_act,
|
||||
weight_init=weight).to_float(compute_type)
|
||||
|
||||
self.shape_from = (-1, from_seq_length, num_attention_heads, size_per_head)
|
||||
self.shape_to = (-1, to_seq_length, num_attention_heads, size_per_head)
|
||||
|
||||
self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
|
||||
self.multiply = P.Mul()
|
||||
|
@ -513,7 +436,6 @@ class BertAttention(nn.Cell):
|
|||
self.trans_shape_relative = (2, 0, 1, 3)
|
||||
self.trans_shape_position = (1, 2, 0, 3)
|
||||
self.multiply_data = -10000.0
|
||||
self.batch_num = batch_size * num_attention_heads
|
||||
self.matmul = P.BatchMatMul()
|
||||
|
||||
self.softmax = nn.Softmax()
|
||||
|
@ -526,9 +448,9 @@ class BertAttention(nn.Cell):
|
|||
self.cast = P.Cast()
|
||||
self.get_dtype = P.DType()
|
||||
if do_return_2d_tensor:
|
||||
self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head)
|
||||
self.shape_return = (-1, num_attention_heads * size_per_head)
|
||||
else:
|
||||
self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)
|
||||
self.shape_return = (-1, from_seq_length, num_attention_heads * size_per_head)
|
||||
|
||||
self.cast_compute_type = SaturateCast(dst_type=compute_type)
|
||||
if self.use_relative_positions:
|
||||
|
@ -540,8 +462,7 @@ class BertAttention(nn.Cell):
|
|||
use_one_hot_embeddings=use_one_hot_embeddings)
|
||||
|
||||
def construct(self, from_tensor, to_tensor, attention_mask):
|
||||
"""construct of BertAttention"""
|
||||
# reshape 2d/3d input tensors to 2d
|
||||
"""reshape 2d/3d input tensors to 2d"""
|
||||
from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
|
||||
to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
|
||||
query_out = self.query_layer(from_tensor_2d)
|
||||
|
@ -565,7 +486,7 @@ class BertAttention(nn.Cell):
|
|||
# query_layer_r is [F, B * N, H]
|
||||
query_layer_r = self.reshape(query_layer_t,
|
||||
(self.from_seq_length,
|
||||
self.batch_num,
|
||||
-1,
|
||||
self.size_per_head))
|
||||
# key_position_scores is [F, B * N, F|T]
|
||||
key_position_scores = self.matmul_trans_b(query_layer_r,
|
||||
|
@ -573,7 +494,7 @@ class BertAttention(nn.Cell):
|
|||
# key_position_scores_r is [F, B, N, F|T]
|
||||
key_position_scores_r = self.reshape(key_position_scores,
|
||||
(self.from_seq_length,
|
||||
self.batch_size,
|
||||
-1,
|
||||
self.num_attention_heads,
|
||||
self.from_seq_length))
|
||||
# key_position_scores_r_t is [B, N, F, F|T]
|
||||
|
@ -609,7 +530,7 @@ class BertAttention(nn.Cell):
|
|||
attention_probs_r = self.reshape(
|
||||
attention_probs_t,
|
||||
(self.from_seq_length,
|
||||
self.batch_num,
|
||||
-1,
|
||||
self.to_seq_length))
|
||||
# value_position_scores is [F, B * N, H]
|
||||
value_position_scores = self.matmul(attention_probs_r,
|
||||
|
@ -617,7 +538,7 @@ class BertAttention(nn.Cell):
|
|||
# value_position_scores_r is [F, B, N, H]
|
||||
value_position_scores_r = self.reshape(value_position_scores,
|
||||
(self.from_seq_length,
|
||||
self.batch_size,
|
||||
-1,
|
||||
self.num_attention_heads,
|
||||
self.size_per_head))
|
||||
# value_position_scores_r_t is [B, N, F, H]
|
||||
|
@ -636,7 +557,6 @@ class BertSelfAttention(nn.Cell):
|
|||
Apply self-attention.
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size of input dataset.
|
||||
seq_length (int): Length of input sequence.
|
||||
hidden_size (int): Size of the bert encoder layers.
|
||||
num_attention_heads (int): Number of attention heads. Default: 12.
|
||||
|
@ -648,9 +568,7 @@ class BertSelfAttention(nn.Cell):
|
|||
use_relative_positions (bool): Specifies whether to use relative positions. Default: False.
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in BertSelfAttention. Default: mstype.float32.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
batch_size,
|
||||
seq_length,
|
||||
hidden_size,
|
||||
num_attention_heads=12,
|
||||
|
@ -659,8 +577,7 @@ class BertSelfAttention(nn.Cell):
|
|||
initializer_range=0.02,
|
||||
hidden_dropout_prob=0.1,
|
||||
use_relative_positions=False,
|
||||
compute_type=mstype.float32,
|
||||
enable_fused_layernorm=False):
|
||||
compute_type=mstype.float32):
|
||||
super(BertSelfAttention, self).__init__()
|
||||
if hidden_size % num_attention_heads != 0:
|
||||
raise ValueError("The hidden size (%d) is not a multiple of the number "
|
||||
|
@ -669,7 +586,6 @@ class BertSelfAttention(nn.Cell):
|
|||
self.size_per_head = int(hidden_size / num_attention_heads)
|
||||
|
||||
self.attention = BertAttention(
|
||||
batch_size=batch_size,
|
||||
from_tensor_width=hidden_size,
|
||||
to_tensor_width=hidden_size,
|
||||
from_seq_length=seq_length,
|
||||
|
@ -688,13 +604,11 @@ class BertSelfAttention(nn.Cell):
|
|||
out_channels=hidden_size,
|
||||
initializer_range=initializer_range,
|
||||
dropout_prob=hidden_dropout_prob,
|
||||
compute_type=compute_type,
|
||||
enable_fused_layernorm=enable_fused_layernorm)
|
||||
compute_type=compute_type)
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = (-1, hidden_size)
|
||||
|
||||
def construct(self, input_tensor, attention_mask):
|
||||
"""construct of BertSelfAttention"""
|
||||
input_tensor = self.reshape(input_tensor, self.shape)
|
||||
attention_output = self.attention(input_tensor, input_tensor, attention_mask)
|
||||
output = self.output(attention_output, input_tensor)
|
||||
|
@ -706,7 +620,6 @@ class BertEncoderCell(nn.Cell):
|
|||
Encoder cells used in BertTransformer.
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size of input dataset.
|
||||
hidden_size (int): Size of the bert encoder layers. Default: 768.
|
||||
seq_length (int): Length of input sequence. Default: 512.
|
||||
num_attention_heads (int): Number of attention heads. Default: 12.
|
||||
|
@ -720,9 +633,7 @@ class BertEncoderCell(nn.Cell):
|
|||
hidden_act (str): Activation function. Default: "gelu".
|
||||
compute_type (:class:`mindspore.dtype`): Compute type in attention. Default: mstype.float32.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
batch_size,
|
||||
hidden_size=768,
|
||||
seq_length=512,
|
||||
num_attention_heads=12,
|
||||
|
@ -733,11 +644,9 @@ class BertEncoderCell(nn.Cell):
|
|||
hidden_dropout_prob=0.1,
|
||||
use_relative_positions=False,
|
||||
hidden_act="gelu",
|
||||
compute_type=mstype.float32,
|
||||
enable_fused_layernorm=False):
|
||||
compute_type=mstype.float32):
|
||||
super(BertEncoderCell, self).__init__()
|
||||
self.attention = BertSelfAttention(
|
||||
batch_size=batch_size,
|
||||
hidden_size=hidden_size,
|
||||
seq_length=seq_length,
|
||||
num_attention_heads=num_attention_heads,
|
||||
|
@ -746,27 +655,18 @@ class BertEncoderCell(nn.Cell):
|
|||
initializer_range=initializer_range,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
use_relative_positions=use_relative_positions,
|
||||
compute_type=compute_type,
|
||||
enable_fused_layernorm=enable_fused_layernorm)
|
||||
self.intermediate = Dense_Thor(in_channels=hidden_size,
|
||||
out_channels=intermediate_size,
|
||||
weight_init=TruncatedNormal(initializer_range),
|
||||
has_bias=True,
|
||||
bias_init='zeros',
|
||||
damping=damping,
|
||||
loss_scale=loss_scale,
|
||||
frequency=frequency,
|
||||
activation=hidden_act,
|
||||
batch_size=batch_size).to_float(compute_type)
|
||||
compute_type=compute_type)
|
||||
self.intermediate = nn.Dense(in_channels=hidden_size,
|
||||
out_channels=intermediate_size,
|
||||
activation=hidden_act,
|
||||
weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
|
||||
self.output = BertOutput(in_channels=intermediate_size,
|
||||
out_channels=hidden_size,
|
||||
initializer_range=initializer_range,
|
||||
dropout_prob=hidden_dropout_prob,
|
||||
compute_type=compute_type,
|
||||
enable_fused_layernorm=enable_fused_layernorm)
|
||||
compute_type=compute_type)
|
||||
|
||||
def construct(self, hidden_states, attention_mask):
|
||||
"""construct of BertEncoderCell"""
|
||||
# self-attention
|
||||
attention_output = self.attention(hidden_states, attention_mask)
|
||||
# feed construct
|
||||
|
@ -781,7 +681,6 @@ class BertTransformer(nn.Cell):
|
|||
Multi-layer bert transformer.
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size of input dataset.
|
||||
hidden_size (int): Size of the encoder layers.
|
||||
seq_length (int): Length of input sequence.
|
||||
num_hidden_layers (int): Number of hidden layers in encoder cells.
|
||||
|
@ -797,9 +696,7 @@ class BertTransformer(nn.Cell):
|
|||
compute_type (:class:`mindspore.dtype`): Compute type in BertTransformer. Default: mstype.float32.
|
||||
return_all_encoders (bool): Specifies whether to return all encoders. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
batch_size,
|
||||
hidden_size,
|
||||
seq_length,
|
||||
num_hidden_layers,
|
||||
|
@ -812,15 +709,13 @@ class BertTransformer(nn.Cell):
|
|||
use_relative_positions=False,
|
||||
hidden_act="gelu",
|
||||
compute_type=mstype.float32,
|
||||
return_all_encoders=False,
|
||||
enable_fused_layernorm=False):
|
||||
return_all_encoders=False):
|
||||
super(BertTransformer, self).__init__()
|
||||
self.return_all_encoders = return_all_encoders
|
||||
|
||||
layers = []
|
||||
for _ in range(num_hidden_layers):
|
||||
layer = BertEncoderCell(batch_size=batch_size,
|
||||
hidden_size=hidden_size,
|
||||
layer = BertEncoderCell(hidden_size=hidden_size,
|
||||
seq_length=seq_length,
|
||||
num_attention_heads=num_attention_heads,
|
||||
intermediate_size=intermediate_size,
|
||||
|
@ -830,18 +725,17 @@ class BertTransformer(nn.Cell):
|
|||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
use_relative_positions=use_relative_positions,
|
||||
hidden_act=hidden_act,
|
||||
compute_type=compute_type,
|
||||
enable_fused_layernorm=enable_fused_layernorm)
|
||||
compute_type=compute_type)
|
||||
layers.append(layer)
|
||||
|
||||
self.layers = nn.CellList(layers)
|
||||
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = (-1, hidden_size)
|
||||
self.out_shape = (batch_size, seq_length, hidden_size)
|
||||
self.out_shape = (-1, seq_length, hidden_size)
|
||||
|
||||
def construct(self, input_tensor, attention_mask):
|
||||
"""construct of BertTransformer"""
|
||||
"""Multi-layer bert transformer."""
|
||||
prev_output = self.reshape(input_tensor, self.shape)
|
||||
|
||||
all_encoder_layers = ()
|
||||
|
@ -866,28 +760,15 @@ class CreateAttentionMaskFromInputMask(nn.Cell):
|
|||
Args:
|
||||
config (Class): Configuration for BertModel.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(CreateAttentionMaskFromInputMask, self).__init__()
|
||||
self.input_mask_from_dataset = config.input_mask_from_dataset
|
||||
self.input_mask = None
|
||||
|
||||
if not self.input_mask_from_dataset:
|
||||
self.input_mask = initializer(
|
||||
"ones", [config.batch_size, config.seq_length], mstype.int32).init_data()
|
||||
|
||||
self.cast = P.Cast()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = (config.batch_size, 1, config.seq_length)
|
||||
self.broadcast_ones = initializer(
|
||||
"ones", [config.batch_size, config.seq_length, 1], mstype.float32).init_data()
|
||||
self.batch_matmul = P.BatchMatMul()
|
||||
self.shape = (-1, 1, config.seq_length)
|
||||
|
||||
def construct(self, input_mask):
|
||||
"""construct of CreateAttentionMaskFromInputMask"""
|
||||
if not self.input_mask_from_dataset:
|
||||
input_mask = self.input_mask
|
||||
|
||||
attention_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
|
||||
return attention_mask
|
||||
|
||||
|
@ -901,7 +782,6 @@ class BertModel(nn.Cell):
|
|||
is_training (bool): True for training mode. False for eval mode.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
is_training,
|
||||
|
@ -912,9 +792,6 @@ class BertModel(nn.Cell):
|
|||
config.hidden_dropout_prob = 0.0
|
||||
config.attention_probs_dropout_prob = 0.0
|
||||
|
||||
self.input_mask_from_dataset = config.input_mask_from_dataset
|
||||
self.token_type_ids_from_dataset = config.token_type_ids_from_dataset
|
||||
self.batch_size = config.batch_size
|
||||
self.seq_length = config.seq_length
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
|
@ -922,23 +799,14 @@ class BertModel(nn.Cell):
|
|||
self.token_type_ids = None
|
||||
|
||||
self.last_idx = self.num_hidden_layers - 1
|
||||
output_embedding_shape = [self.batch_size, self.seq_length,
|
||||
self.embedding_size]
|
||||
output_embedding_shape = [-1, self.seq_length, self.embedding_size]
|
||||
|
||||
if not self.token_type_ids_from_dataset:
|
||||
self.token_type_ids = initializer(
|
||||
"zeros", [self.batch_size, self.seq_length], mstype.int32).init_data()
|
||||
|
||||
self.bert_embedding_lookup = Embedding_Thor(
|
||||
self.bert_embedding_lookup = nn.Embedding(
|
||||
vocab_size=config.vocab_size,
|
||||
embedding_size=self.embedding_size,
|
||||
embedding_shape=output_embedding_shape,
|
||||
use_one_hot_embeddings=use_one_hot_embeddings,
|
||||
initializer_range=config.initializer_range,
|
||||
batch_size=batch_size,
|
||||
damping=damping,
|
||||
loss_scale=loss_scale,
|
||||
frequency=frequency)
|
||||
use_one_hot=use_one_hot_embeddings,
|
||||
embedding_table=TruncatedNormal(config.initializer_range))
|
||||
|
||||
self.bert_embedding_postprocessor = EmbeddingPostprocessor(
|
||||
embedding_size=self.embedding_size,
|
||||
embedding_shape=output_embedding_shape,
|
||||
|
@ -951,7 +819,6 @@ class BertModel(nn.Cell):
|
|||
dropout_prob=config.hidden_dropout_prob)
|
||||
|
||||
self.bert_encoder = BertTransformer(
|
||||
batch_size=self.batch_size,
|
||||
hidden_size=self.hidden_size,
|
||||
seq_length=self.seq_length,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
|
@ -964,8 +831,7 @@ class BertModel(nn.Cell):
|
|||
use_relative_positions=config.use_relative_positions,
|
||||
hidden_act=config.hidden_act,
|
||||
compute_type=config.compute_type,
|
||||
return_all_encoders=True,
|
||||
enable_fused_layernorm=config.enable_fused_layernorm)
|
||||
return_all_encoders=True)
|
||||
|
||||
self.cast = P.Cast()
|
||||
self.dtype = config.dtype
|
||||
|
@ -973,25 +839,16 @@ class BertModel(nn.Cell):
|
|||
self.slice = P.StridedSlice()
|
||||
|
||||
self.squeeze_1 = P.Squeeze(axis=1)
|
||||
self.dense = Dense_Thor(in_channels=self.hidden_size,
|
||||
out_channels=self.hidden_size,
|
||||
weight_init=TruncatedNormal(config.initializer_range),
|
||||
has_bias=True,
|
||||
bias_init='zeros',
|
||||
damping=damping,
|
||||
loss_scale=loss_scale,
|
||||
frequency=frequency,
|
||||
activation="tanh",
|
||||
batch_size=batch_size).to_float(config.compute_type)
|
||||
self.dense = nn.Dense(self.hidden_size, self.hidden_size,
|
||||
activation="tanh",
|
||||
weight_init=TruncatedNormal(config.initializer_range)).to_float(config.compute_type)
|
||||
self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)
|
||||
|
||||
def construct(self, input_ids, token_type_ids, input_mask):
|
||||
"""construct of BertModel"""
|
||||
|
||||
"""Bidirectional Encoder Representations from Transformers."""
|
||||
# embedding
|
||||
if not self.token_type_ids_from_dataset:
|
||||
token_type_ids = self.token_type_ids
|
||||
word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids)
|
||||
embedding_tables = self.bert_embedding_lookup.embedding_table
|
||||
word_embeddings = self.bert_embedding_lookup(input_ids)
|
||||
embedding_output = self.bert_embedding_postprocessor(token_type_ids,
|
||||
word_embeddings)
|
||||
|
||||
|
@ -1005,9 +862,10 @@ class BertModel(nn.Cell):
|
|||
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
|
||||
|
||||
# pooler
|
||||
batch_size = P.Shape()(input_ids)[0]
|
||||
sequence_slice = self.slice(sequence_output,
|
||||
(0, 0, 0),
|
||||
(self.batch_size, 1, self.hidden_size),
|
||||
(batch_size, 1, self.hidden_size),
|
||||
(1, 1, 1))
|
||||
first_token = self.squeeze_1(sequence_slice)
|
||||
pooled_output = self.dense(first_token)
|
||||
|
|
|
@ -1,89 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in dataset.py, run_pretrain.py
|
||||
Including two kinds of network: \
|
||||
base: Goole BERT-base(the base version of BERT model).
|
||||
large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of \
|
||||
Functional Relative Posetional Encoding as an effective positional encoding scheme).
|
||||
"""
|
||||
import mindspore.common.dtype as mstype
|
||||
from .bert_model import BertConfig
|
||||
from .config import cfg
|
||||
|
||||
if cfg.bert_network == 'base':
|
||||
bert_net_cfg = BertConfig(
|
||||
batch_size=cfg.Thor.batch_size,
|
||||
seq_length=128,
|
||||
vocab_size=21128,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16
|
||||
)
|
||||
if cfg.bert_network == 'nezha':
|
||||
bert_net_cfg = BertConfig(
|
||||
batch_size=cfg.Thor.batch_size,
|
||||
seq_length=128,
|
||||
vocab_size=21128,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
intermediate_size=4096,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=True,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16
|
||||
)
|
||||
if cfg.bert_network == 'large':
|
||||
bert_net_cfg = BertConfig(
|
||||
batch_size=cfg.Thor.batch_size,
|
||||
seq_length=512,
|
||||
vocab_size=30522,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
intermediate_size=4096,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
enable_fused_layernorm=True
|
||||
)
|
|
@ -16,15 +16,88 @@
|
|||
network config setting, will be used in dataset.py, run_pretrain.py
|
||||
"""
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from .bert_model import BertConfig
|
||||
cfg = edict({
|
||||
'batch_size': 12,
|
||||
'bert_network': 'large',
|
||||
'optimizer': 'Thor',
|
||||
'Thor': edict({
|
||||
'lr_max': 0.0034,
|
||||
'lr_min': 3.244e-5,
|
||||
'lr_power': 1.0,
|
||||
'lr_total_steps': 30000,
|
||||
'damping_max': 5e-2,
|
||||
'damping_min': 1e-6,
|
||||
'damping_power': 1.0,
|
||||
'damping_total_steps': 30000,
|
||||
'momentum': 0.9,
|
||||
'weight_decay': 5e-4,
|
||||
'loss_scale': 1,
|
||||
'loss_scale': 1.0,
|
||||
'frequency': 100,
|
||||
'batch_size': 12,
|
||||
}),
|
||||
})
|
||||
|
||||
'''
|
||||
Including two kinds of network: \
|
||||
base: Google BERT-base(the base version of BERT model).
|
||||
large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of \
|
||||
Functional Relative Posetional Encoding as an effective positional encoding scheme).
|
||||
'''
|
||||
if cfg.bert_network == 'base':
|
||||
cfg.batch_size = 64
|
||||
bert_net_cfg = BertConfig(
|
||||
seq_length=128,
|
||||
vocab_size=21128,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16
|
||||
)
|
||||
if cfg.bert_network == 'nezha':
|
||||
cfg.batch_size = 96
|
||||
bert_net_cfg = BertConfig(
|
||||
seq_length=128,
|
||||
vocab_size=21128,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
intermediate_size=4096,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16
|
||||
)
|
||||
if cfg.bert_network == 'large':
|
||||
cfg.batch_size = 12
|
||||
bert_net_cfg = BertConfig(
|
||||
seq_length=512,
|
||||
vocab_size=30522,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
intermediate_size=4096,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16
|
||||
)
|
||||
|
|
|
@ -20,7 +20,7 @@ import mindspore.common.dtype as mstype
|
|||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
from mindspore import log as logger
|
||||
from .bert_net_config import bert_net_cfg
|
||||
from .config import cfg
|
||||
|
||||
|
||||
def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None):
|
||||
|
@ -47,7 +47,7 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None,
|
|||
data_set = data_set.map(operations=type_cast_op, input_columns="input_mask")
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="input_ids")
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(bert_net_cfg.batch_size, drop_remainder=True)
|
||||
data_set = data_set.batch(cfg.batch_size, drop_remainder=True)
|
||||
logger.info("data size: {}".format(data_set.get_dataset_size()))
|
||||
logger.info("repeat count: {}".format(data_set.get_repeat_count()))
|
||||
return data_set
|
||||
|
|
|
@ -1,176 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dataset help for minddata dataset"""
|
||||
import os
|
||||
|
||||
from mindspore import context
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.parallel._utils import _get_device_num, _need_to_full, _to_full_shapes
|
||||
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
|
||||
|
||||
|
||||
def _send_data(dataset, epoch_num):
|
||||
"""Engine dataset to write data to tdt queue."""
|
||||
if not hasattr(dataset, '__has_sent__'):
|
||||
exec_dataset = dataset.__transfer_dataset__
|
||||
exec_dataset.send(epoch_num)
|
||||
dataset.__has_sent__ = True
|
||||
|
||||
|
||||
def _send_data_no_flag(dataset, epoch_num):
|
||||
"""Engine dataset to write data to tdt queue directly."""
|
||||
exec_dataset = dataset.__transfer_dataset__
|
||||
exec_dataset.send(epoch_num)
|
||||
|
||||
|
||||
class DatasetHelper:
|
||||
"""
|
||||
Help function to use the Minddata dataset.
|
||||
|
||||
According to different context, change the iter of dataset, to use the same for loop in different context.
|
||||
|
||||
Note:
|
||||
The iter of DatasetHelper will give one epoch data.
|
||||
|
||||
Args:
|
||||
dataset (DataSet): The training dataset iterator.
|
||||
dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True.
|
||||
sink_size (int): Control the amount of data each sink.
|
||||
If sink_size=-1, sink the complete dataset each epoch.
|
||||
If sink_size>0, sink sink_size data each epoch. Default: -1.
|
||||
|
||||
Examples:
|
||||
>>> dataset_helper = DatasetHelper(dataset)
|
||||
>>> for inputs in dataset_helper:
|
||||
>>> outputs = network(*inputs)
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=0):
|
||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||
Validator.check_is_int(sink_size)
|
||||
if sink_size < -1 or sink_size == 0:
|
||||
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
|
||||
|
||||
if dataset_sink_mode:
|
||||
if context.get_context("enable_ge"):
|
||||
iterclass = _DatasetIterGE
|
||||
else:
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
iterclass = _DatasetIterMSLoopSink
|
||||
elif context.get_context("device_target") == "GPU":
|
||||
ms_role = os.getenv("MS_ROLE")
|
||||
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
||||
iterclass = _DatasetIterPSLite
|
||||
else:
|
||||
iterclass = _DatasetIterMS
|
||||
elif context.get_context("device_target") == "CPU":
|
||||
raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
|
||||
self.iter = iterclass(dataset, sink_size, epoch_num, iter_first_order)
|
||||
else:
|
||||
iterclass = _DatasetIterNormal
|
||||
self.iter = iterclass(dataset)
|
||||
|
||||
def __iter__(self):
|
||||
return self.iter.__iter__()
|
||||
|
||||
# A temp solution for loop sink. Delete later
|
||||
def types_shapes(self):
|
||||
"""Get the types and shapes from dataset on current config."""
|
||||
return self.iter.types_shapes()
|
||||
|
||||
def sink_size(self):
|
||||
"""Get sink_size for every iteration."""
|
||||
return self.iter.get_sink_size()
|
||||
|
||||
def stop_send(self):
|
||||
"""Free up resources about data sink."""
|
||||
self.iter.stop_send()
|
||||
|
||||
|
||||
class _DatasetIter:
|
||||
"""Base iter for dataset helper"""
|
||||
|
||||
def __init__(self, dataset, sink_size, epoch_num):
|
||||
self.dataset = dataset
|
||||
self.sink_size = sink_size
|
||||
self.sink_count = 1
|
||||
|
||||
if not hasattr(dataset, '__transfer_dataset__'):
|
||||
if hasattr(dataset, '__loop_size__'):
|
||||
self.sink_size = dataset.__loop_size__
|
||||
dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size)
|
||||
|
||||
if not hasattr(dataset, '__no_send__'):
|
||||
_send_data(dataset, epoch_num)
|
||||
else:
|
||||
_send_data_no_flag(dataset, epoch_num)
|
||||
|
||||
self.stop_send = dataset.__transfer_dataset__.stop_send
|
||||
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
|
||||
|
||||
def __iter__(self):
|
||||
self.index = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.index >= self.sink_count:
|
||||
raise StopIteration()
|
||||
self.index += 1
|
||||
return self.op()
|
||||
|
||||
def types_shapes(self):
|
||||
return self.dataset_types, self.dataset_shapes
|
||||
|
||||
def get_sink_count(self, dataset, sink_size, iter_first_order):
|
||||
sink_count = 1
|
||||
if hasattr(dataset, '__loop_size__'):
|
||||
loop_size = dataset.__loop_size__ + iter_first_order
|
||||
sink_count = int(sink_size / loop_size) * 2
|
||||
return sink_count
|
||||
|
||||
def get_sink_size(self):
|
||||
"""get sink_size to device"""
|
||||
sink_size = 1
|
||||
if hasattr(self.dataset, '__loop_size__'):
|
||||
sink_size = self.dataset.__loop_size__
|
||||
else:
|
||||
if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend":
|
||||
if self.sink_size > 0:
|
||||
sink_size = self.sink_size
|
||||
else:
|
||||
sink_size = self.dataset.get_dataset_size()
|
||||
return sink_size
|
||||
|
||||
|
||||
class _DatasetIterMSLoopSink(_DatasetIter):
|
||||
"""Iter for context, the device_target is Ascend."""
|
||||
|
||||
def __init__(self, dataset, sink_size, epoch_num, iter_first_order):
|
||||
super().__init__(dataset, sink_size, epoch_num)
|
||||
self.sink_count = self.get_sink_count(dataset, sink_size, iter_first_order)
|
||||
ms_role = os.getenv("MS_ROLE")
|
||||
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
||||
self.sink_count = 1
|
||||
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
|
||||
# use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
|
||||
# compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
|
||||
if _need_to_full():
|
||||
device_num = _get_device_num()
|
||||
self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)
|
||||
|
||||
def op():
|
||||
return tuple()
|
||||
|
||||
self.op = op
|
|
@ -30,10 +30,10 @@ cfg = edict({
|
|||
'finetune_ckpt': '',
|
||||
'use_crf': False,
|
||||
'clue_benchmark': False,
|
||||
'batch_size': 12,
|
||||
})
|
||||
|
||||
bert_net_cfg = BertConfig(
|
||||
batch_size=8 if not cfg.clue_benchmark else 1,
|
||||
seq_length=512,
|
||||
vocab_size=30522,
|
||||
hidden_size=1024,
|
||||
|
@ -47,8 +47,6 @@ bert_net_cfg = BertConfig(
|
|||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
)
|
||||
|
|
|
@ -1,185 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""grad_reducer_thor"""
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.communication.management import GlobalComm, get_group_size
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp
|
||||
|
||||
reduce_opt = C.MultitypeFuncGraph("reduce_opt")
|
||||
|
||||
_all_reduce_G = AllReduce()
|
||||
|
||||
|
||||
def _init_optimizer_allreduce(group):
|
||||
global _all_reduce_G
|
||||
_all_reduce_G = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
|
||||
_all_reduce_G.add_prim_attr('fusion', group)
|
||||
|
||||
|
||||
@reduce_opt.register("Function", "Number", "Tensor")
|
||||
def _tensors_allreduce_mean(mul, degree, grad):
|
||||
degree = F.scalar_cast(degree, F.dtype(grad))
|
||||
grad = _all_reduce_G(grad)
|
||||
cast_op = P.Cast()
|
||||
return mul(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad)))
|
||||
|
||||
|
||||
@reduce_opt.register("Bool", "Tensor")
|
||||
def _tensors_allreduce(allreduce_filter, grad):
|
||||
if allreduce_filter:
|
||||
return _all_reduce_G(grad)
|
||||
return grad
|
||||
|
||||
|
||||
_get_datatype = C.MultitypeFuncGraph("_get_datatype")
|
||||
|
||||
|
||||
@_get_datatype.register("Tensor")
|
||||
def _tensors_get_datatype(grad):
|
||||
"""
|
||||
Acquire gradient datatype.
|
||||
|
||||
Args:
|
||||
grad (Tensor): The gradient tensor before operation.
|
||||
|
||||
Returns:
|
||||
mstype, the datatype of gradient.
|
||||
"""
|
||||
return F.dtype(grad)
|
||||
|
||||
|
||||
_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")
|
||||
|
||||
|
||||
@_cast_datatype.register("TypeType", "Tensor")
|
||||
def _tensors_cast_datatype(datatype, grad):
|
||||
"""
|
||||
Cast gradient to datatype.
|
||||
|
||||
Args:
|
||||
datatype (mstype): the destination datatype of gradient.
|
||||
grad (Tensor): The gradient tensor before operation.
|
||||
|
||||
Returns:
|
||||
Tensor, the gradient tensor after operation.
|
||||
"""
|
||||
return F.cast(grad, datatype)
|
||||
|
||||
|
||||
class DistributedGradReducerThor(Cell):
|
||||
"""
|
||||
A distributed optimizer.
|
||||
|
||||
Constructs a gradient reducer Cell, which applies communication and average operations on
|
||||
single-process gradient values.
|
||||
|
||||
Args:
|
||||
parameters (list): the parameters to be updated.
|
||||
mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: False.
|
||||
degree (int): The mean coefficient. Usually it equals to device number. Default: None.
|
||||
|
||||
Raises:
|
||||
ValueError: If degree is not a int or less than 0.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.communication import init, get_group_size
|
||||
>>> from mindspore.ops import composite as C
|
||||
>>> from mindspore.ops import operations as P
|
||||
>>> from mindspore.ops import functional as F
|
||||
>>> from mindspore import context
|
||||
>>> from mindspore import nn
|
||||
>>> from mindspore import ParameterTuple
|
||||
>>> from mindspore.context import ParallelMode
|
||||
>>>
|
||||
>>> device_id = int(os.environ["DEVICE_ID"])
|
||||
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
|
||||
>>> device_id=int(device_id), enable_hccl=True)
|
||||
>>> init()
|
||||
>>> context.reset_auto_parallel_context()
|
||||
>>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
|
||||
>>>
|
||||
>>>
|
||||
>>> class TrainingWrapper(nn.Cell):
|
||||
>>> def __init__(self, network, optimizer, sens=1.0):
|
||||
>>> super(TrainingWrapper, self).__init__(auto_prefix=False)
|
||||
>>> self.network = network
|
||||
>>> self.network.add_flags(defer_inline=True)
|
||||
>>> self.weights = ParameterTuple(network.trainable_params())
|
||||
>>> self.optimizer = optimizer
|
||||
>>> self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
>>> self.sens = sens
|
||||
>>> self.reducer_flag = False
|
||||
>>> self.grad_reducer = None
|
||||
>>> self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
>>> if self.parallel_mode in [ParallelMode.DATA_PARALLEL,
|
||||
>>> ParallelMode.HYBRID_PARALLEL]:
|
||||
>>> self.reducer_flag = True
|
||||
>>> if self.reducer_flag:
|
||||
>>> mean = context.get_auto_parallel_context("gradients_mean")
|
||||
>>> if mean.get_device_num_is_set():
|
||||
>>> degree = context.get_auto_parallel_context("device_num")
|
||||
>>> else:
|
||||
>>> degree = get_group_size()
|
||||
>>> self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
>>>
|
||||
>>> def construct(self, *args):
|
||||
>>> weights = self.weights
|
||||
>>> loss = self.network(*args)
|
||||
>>> sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
>>> grads = self.grad(self.network, weights)(*args, sens)
|
||||
>>> if self.reducer_flag:
|
||||
>>> # apply grad reducer on grads
|
||||
>>> grads = self.grad_reducer(grads)
|
||||
>>> return F.depend(loss, self.optimizer(grads))
|
||||
>>>
|
||||
>>> network = Net()
|
||||
>>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> train_cell = TrainingWrapper(network, optimizer)
|
||||
>>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
|
||||
>>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
|
||||
>>> grads = train_cell(inputs, label)
|
||||
"""
|
||||
|
||||
def __init__(self, parameters, group, mean=True, degree=None):
|
||||
super(DistributedGradReducerThor, self).__init__(auto_prefix=False)
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.mul = P.Mul()
|
||||
if degree is None:
|
||||
self.degree = get_group_size()
|
||||
else:
|
||||
if not isinstance(degree, int) or degree <= 0:
|
||||
raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int")
|
||||
self.degree = degree
|
||||
self.mean = mean
|
||||
self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
|
||||
_init_optimizer_allreduce(group)
|
||||
|
||||
def construct(self, grads):
|
||||
"""construct of DistributedGradReducerThor"""
|
||||
# In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
|
||||
# result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
|
||||
# and cast back after the operation.
|
||||
datatypes = self.hyper_map(F.partial(_get_datatype), grads)
|
||||
grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
|
||||
|
||||
if self.mean:
|
||||
new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads)
|
||||
else:
|
||||
new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads)
|
||||
|
||||
new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
|
||||
return new_grad
|
|
@ -1,70 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""learning rate generator"""
|
||||
import numpy as np
|
||||
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
|
||||
def get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps, poly_power):
|
||||
"""
|
||||
generate learning rate array
|
||||
|
||||
Args:
|
||||
lr_init(float): init learning rate
|
||||
lr_end(float): end learning rate
|
||||
lr_max(float): max learning rate
|
||||
warmup_steps(int): number of warmup epochs
|
||||
total_steps(int): total epoch of training
|
||||
poly_power(int): poly learning rate power
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array
|
||||
"""
|
||||
lr_each_step = []
|
||||
if warmup_steps != 0:
|
||||
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
|
||||
else:
|
||||
inc_each_step = 0
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = float(lr_init) + inc_each_step * float(i)
|
||||
else:
|
||||
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
|
||||
lr = float(lr_max - lr_end) * (base ** poly_power)
|
||||
lr = lr + lr_end
|
||||
if lr < 0.0:
|
||||
lr = 0.0
|
||||
lr_each_step.append(lr)
|
||||
|
||||
learning_rate = np.array(lr_each_step).astype(np.float32)
|
||||
current_step = global_step
|
||||
learning_rate = learning_rate[current_step:]
|
||||
return learning_rate
|
||||
|
||||
|
||||
# bert thor hyperparam setting
|
||||
def get_bert_lr():
|
||||
learning_rate = Tensor(
|
||||
get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=3.1e-3, warmup_steps=0, total_steps=30000,
|
||||
poly_power=1))
|
||||
return learning_rate
|
||||
|
||||
|
||||
def get_bert_damping():
|
||||
damping = Tensor(
|
||||
get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=5e-2, warmup_steps=0, total_steps=30000,
|
||||
poly_power=1))
|
||||
return damping
|
|
@ -1,773 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Model."""
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
|
||||
import numpy as np
|
||||
from mindspore._c_expression import init_exec_dataset
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
from mindspore import nn
|
||||
from mindspore._checkparam import check_input_data, check_output_data, Validator
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.dtype import pytype_to_dtype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.metrics import Loss
|
||||
from mindspore.nn.metrics import get_metrics
|
||||
from mindspore.train.dataset_helper import connect_network_with_dataset
|
||||
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
|
||||
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
||||
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
|
||||
from mindspore.parallel._utils import _need_to_full
|
||||
from mindspore.train import amp
|
||||
from mindspore.parallel._utils import _to_full_tensor
|
||||
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
|
||||
from mindspore.context import ParallelMode
|
||||
from .dataset_helper import DatasetHelper
|
||||
|
||||
|
||||
def _convert_type(types):
|
||||
"""
|
||||
Convert from numpy type to tensor type.
|
||||
|
||||
Args:
|
||||
types (list): Numpy type list of element in dataset.
|
||||
|
||||
Returns:
|
||||
list, list of element in dataset.
|
||||
"""
|
||||
ms_types = []
|
||||
for np_type in types:
|
||||
ms_type = pytype_to_dtype(np_type)
|
||||
ms_types.append(ms_type)
|
||||
return ms_types
|
||||
|
||||
|
||||
def _get_types_and_shapes(dataset):
|
||||
"""Get dataset types and shapes."""
|
||||
dataset_types = _convert_type(dataset.output_types())
|
||||
dataset_shapes = dataset.output_shapes()
|
||||
return dataset_types, dataset_shapes
|
||||
|
||||
|
||||
def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
|
||||
"""Initialize and execute the dataset graph."""
|
||||
batch_size = exec_dataset.get_batch_size()
|
||||
input_indexs = exec_dataset.input_indexs
|
||||
|
||||
# transform data format
|
||||
dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
|
||||
init_exec_dataset(exec_dataset.__transfer_dataset__.queue_name,
|
||||
dataset_size,
|
||||
batch_size,
|
||||
dataset_types,
|
||||
dataset_shapes,
|
||||
input_indexs,
|
||||
phase=phase,
|
||||
need_run=False)
|
||||
|
||||
|
||||
class Model:
|
||||
"""
|
||||
High-Level API for Training or Testing.
|
||||
|
||||
`Model` groups layers into an object with training and inference features.
|
||||
|
||||
Args:
|
||||
network (Cell): The training or testing network.
|
||||
loss_fn (Cell): Objective function, if loss_fn is None, the
|
||||
network should contain the logic of loss and grads calculation, and the logic
|
||||
of parallel if needed. Default: None.
|
||||
optimizer (Cell): Optimizer for updating the weights. Default: None.
|
||||
metrics (Union[dict, set]): Dict or set of metrics to be evaluated by the model during
|
||||
training and testing. eg: {'accuracy', 'recall'}. Default: None.
|
||||
eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as
|
||||
`eval_network`. Default: None.
|
||||
eval_indexes (list): In case of defining the `eval_network`, if `eval_indexes` is None, all outputs of
|
||||
`eval_network` would be passed to metrics, otherwise `eval_indexes` must contain three
|
||||
elements, representing the positions of loss value, predict value and label, the loss
|
||||
value would be passed to `Loss` metric, predict value and label would be passed to other
|
||||
metric. Default: None.
|
||||
amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed
|
||||
precision training. Supports [O0, O2, O3]. Default: "O0".
|
||||
|
||||
- O0: Do not change.
|
||||
- O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
|
||||
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
|
||||
|
||||
O2 is recommended on GPU, O3 is recommended on Ascend.
|
||||
|
||||
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
|
||||
scale the loss by LossScaleManager. If it is set, overwrite the level setting. It's a eyword argument.
|
||||
e.g. Use `loss_scale_manager=None` to set the value.
|
||||
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. Default: True.
|
||||
|
||||
Examples:
|
||||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
>>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
|
||||
>>> self.bn = nn.BatchNorm2d(64)
|
||||
>>> self.relu = nn.ReLU()
|
||||
>>> self.flatten = nn.Flatten()
|
||||
>>> self.fc = nn.Dense(64*224*224, 12) # padding=0
|
||||
>>>
|
||||
>>> def construct(self, x):
|
||||
>>> x = self.conv(x)
|
||||
>>> x = self.bn(x)
|
||||
>>> x = self.relu(x)
|
||||
>>> x = self.flatten(x)
|
||||
>>> out = self.fc(x)
|
||||
>>> return out
|
||||
>>>
|
||||
>>> net = Net()
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
||||
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
|
||||
>>> dataset = get_dataset()
|
||||
>>> model.train(2, dataset)
|
||||
"""
|
||||
|
||||
def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
|
||||
eval_indexes=None, amp_level="O0", frequency=278, stop_epoch=100, **kwargs):
|
||||
self._network = network
|
||||
self._loss_fn = loss_fn
|
||||
self._optimizer = optimizer
|
||||
self._loss_scale_manager = None
|
||||
self._loss_scale_manager_set = False
|
||||
self._keep_bn_fp32 = True
|
||||
self._check_kwargs(kwargs)
|
||||
self._amp_level = amp_level
|
||||
self._process_amp_args(kwargs)
|
||||
self._parallel_mode = _get_parallel_mode()
|
||||
self._device_number = _get_device_num()
|
||||
self._global_rank = _get_global_rank()
|
||||
self._parameter_broadcast = _get_parameter_broadcast()
|
||||
self._frequency = frequency
|
||||
self._stop_epoch = stop_epoch
|
||||
|
||||
self._train_network = self._build_train_network()
|
||||
self._build_eval_network(metrics, eval_network, eval_indexes)
|
||||
self._build_predict_network()
|
||||
|
||||
def _process_amp_args(self, kwargs):
|
||||
if self._amp_level in ["O0", "O3"]:
|
||||
self._keep_bn_fp32 = False
|
||||
if 'keep_batchnorm_fp32' in kwargs:
|
||||
self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
|
||||
if 'loss_scale_manager' in kwargs:
|
||||
self._loss_scale_manager = kwargs['loss_scale_manager']
|
||||
self._loss_scale_manager_set = True
|
||||
|
||||
def _check_kwargs(self, kwargs):
|
||||
for arg in kwargs:
|
||||
if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
|
||||
raise ValueError(f"Unsupported arg '{arg}'")
|
||||
|
||||
def _build_train_network(self):
|
||||
"""Build train network"""
|
||||
network = self._network
|
||||
if self._optimizer:
|
||||
if self._loss_scale_manager_set:
|
||||
network = amp.build_train_network(network,
|
||||
self._optimizer,
|
||||
self._loss_fn,
|
||||
level=self._amp_level,
|
||||
loss_scale_manager=self._loss_scale_manager,
|
||||
keep_batchnorm_fp32=self._keep_bn_fp32)
|
||||
else:
|
||||
network = amp.build_train_network(network,
|
||||
self._optimizer,
|
||||
self._loss_fn,
|
||||
level=self._amp_level,
|
||||
keep_batchnorm_fp32=self._keep_bn_fp32)
|
||||
elif self._loss_fn:
|
||||
network = nn.WithLossCell(network, self._loss_fn)
|
||||
# If need to check if loss_fn is not None, but optimizer is None
|
||||
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
network.set_auto_parallel()
|
||||
return network
|
||||
|
||||
def _build_eval_network(self, metrics, eval_network, eval_indexes):
|
||||
"""Build the network for evaluation."""
|
||||
self._metric_fns = get_metrics(metrics)
|
||||
if not self._metric_fns:
|
||||
return
|
||||
|
||||
if eval_network is not None:
|
||||
if eval_indexes is not None and not (isinstance(eval_indexes, list) and len(eval_indexes) == 3):
|
||||
raise ValueError("Eval_indexes must be a list or None. If eval_indexes is a list, length of it \
|
||||
must be three. But got {}".format(eval_indexes))
|
||||
|
||||
self._eval_network = eval_network
|
||||
self._eval_indexes = eval_indexes
|
||||
else:
|
||||
if self._loss_fn is None:
|
||||
raise ValueError("loss_fn can not be None.")
|
||||
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2")
|
||||
self._eval_indexes = [0, 1, 2]
|
||||
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
if self._optimizer:
|
||||
self._eval_network = _VirtualDatasetCell(self._eval_network)
|
||||
self._eval_network.set_auto_parallel()
|
||||
|
||||
def _build_predict_network(self):
|
||||
"""Build the network for prediction."""
|
||||
self._predict_network = self._network
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
self._predict_network = _VirtualDatasetCell(self._network)
|
||||
self._predict_network.set_auto_parallel()
|
||||
|
||||
def _clear_metrics(self):
|
||||
"""Clear metrics local values."""
|
||||
for metric in self._metric_fns.values():
|
||||
metric.clear()
|
||||
|
||||
def _update_metrics(self, outputs):
|
||||
"""Update metrics local values."""
|
||||
if not isinstance(outputs, tuple):
|
||||
raise ValueError("The `outputs` is not tuple.")
|
||||
|
||||
if self._eval_indexes is not None and len(outputs) < 3:
|
||||
raise ValueError("The length of `outputs` must be greater than or equal to 3, \
|
||||
but got {}".format(len(outputs)))
|
||||
|
||||
for metric in self._metric_fns.values():
|
||||
if self._eval_indexes is None:
|
||||
metric.update(*outputs)
|
||||
else:
|
||||
if isinstance(metric, Loss):
|
||||
metric.update(outputs[self._eval_indexes[0]])
|
||||
else:
|
||||
metric.update(outputs[self._eval_indexes[1]], outputs[self._eval_indexes[2]])
|
||||
|
||||
def _get_metrics(self):
|
||||
"""Get metrics local values."""
|
||||
metrics = dict()
|
||||
for key, value in self._metric_fns.items():
|
||||
metrics[key] = value.eval()
|
||||
return metrics
|
||||
|
||||
def _get_scaling_sens(self):
|
||||
"""get the scaling sens"""
|
||||
scaling_sens = 1
|
||||
if self._loss_scale_manager is not None:
|
||||
scaling_sens = self._loss_scale_manager.get_loss_scale()
|
||||
if self._parallel_mode == ParallelMode.DATA_PARALLEL:
|
||||
scaling_sens /= self._device_number
|
||||
return scaling_sens
|
||||
|
||||
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1,
|
||||
iter_first_order=9):
|
||||
"""Initializes dataset."""
|
||||
if dataset_sink_mode and not is_train:
|
||||
dataset.__loop_size__ = 1
|
||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)
|
||||
|
||||
if dataset_sink_mode:
|
||||
network = connect_network_with_dataset(network, dataset_helper)
|
||||
network.set_train(is_train)
|
||||
network.phase = phase
|
||||
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
network.set_auto_parallel()
|
||||
|
||||
return dataset_helper, network
|
||||
|
||||
def init(self, train_dataset=None, valid_dataset=None):
|
||||
"""
|
||||
Initializes compute graphs and data graphs with sink mode.
|
||||
|
||||
Note:
|
||||
Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently.
|
||||
|
||||
Args:
|
||||
train_dataset (Dataset): A training dataset iterator. If define `train_dataset`, training graphs will be
|
||||
initialized. Default: None.
|
||||
valid_dataset (Dataset): A evaluating dataset iterator. If define `valid_dataset`, evaluation graphs will
|
||||
be initialized, and `metrics` in `Model` can not be None. Default: None.
|
||||
|
||||
Examples:
|
||||
>>> train_dataset = get_train_dataset()
|
||||
>>> valid_dataset = get_valid_dataset()
|
||||
>>> net = Net()
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
||||
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'})
|
||||
>>> model.init(train_dataset, valid_dataset)
|
||||
>>> model.train(2, train_dataset)
|
||||
>>> model.eval(valid_dataset)
|
||||
"""
|
||||
if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend":
|
||||
raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')
|
||||
|
||||
if not train_dataset and not valid_dataset:
|
||||
raise ValueError('Both train_dataset and valid_dataset can not be None or empty.')
|
||||
|
||||
_device_number_check(self._parallel_mode, self._device_number)
|
||||
|
||||
if train_dataset:
|
||||
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
|
||||
self._train_network.set_train()
|
||||
self._train_network.phase = 'train'
|
||||
|
||||
if self._parameter_broadcast:
|
||||
self._train_network.set_broadcast_flag()
|
||||
train_dataset.__no_send__ = True
|
||||
train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
|
||||
is_train=True,
|
||||
phase='train',
|
||||
dataset=train_dataset,
|
||||
dataset_sink_mode=True)
|
||||
self._train_network = train_network
|
||||
for inputs in train_dataset_helper:
|
||||
self._train_network.compile(*inputs)
|
||||
break
|
||||
|
||||
if valid_dataset:
|
||||
if not self._metric_fns:
|
||||
raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.')
|
||||
|
||||
self._eval_network.set_train(False)
|
||||
self._eval_network.phase = 'eval'
|
||||
valid_dataset.__no_send__ = True
|
||||
valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
|
||||
is_train=False,
|
||||
phase='eval',
|
||||
dataset=valid_dataset,
|
||||
dataset_sink_mode=True)
|
||||
self._eval_network = eval_network
|
||||
for inputs in valid_dataset_helper:
|
||||
self._eval_network.compile(*inputs)
|
||||
break
|
||||
|
||||
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
|
||||
"""
|
||||
Training.
|
||||
|
||||
Args:
|
||||
epoch (int): Total number of iterations on the data.
|
||||
train_dataset (Dataset): A training dataset iterator. If there is no
|
||||
loss_fn, a tuple with multiply data (data1, data2, data3, ...) will be
|
||||
returned and passed to the network. Otherwise, a tuple (data, label) will
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
callbacks (list): List of callback object. Callbacks which should be executed while training. Default: None.
|
||||
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
|
||||
Configure pynative mode, the training process will be performed with
|
||||
dataset not sink.
|
||||
sink_size (int): Control the amount of data each sink. Default: -1.
|
||||
"""
|
||||
epoch = Validator.check_positive_int(epoch)
|
||||
self._train_network.set_train()
|
||||
|
||||
if self._parameter_broadcast:
|
||||
self._train_network.set_broadcast_flag()
|
||||
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.train_network = self._train_network
|
||||
cb_params.epoch_num = epoch
|
||||
if dataset_sink_mode and sink_size > 0:
|
||||
cb_params.batch_num = sink_size
|
||||
else:
|
||||
cb_params.batch_num = train_dataset.get_dataset_size()
|
||||
cb_params.mode = "train"
|
||||
cb_params.loss_fn = self._loss_fn
|
||||
cb_params.optimizer = self._optimizer
|
||||
cb_params.parallel_mode = self._parallel_mode
|
||||
cb_params.device_number = self._device_number
|
||||
cb_params.train_dataset = train_dataset
|
||||
cb_params.list_callback = self._transform_callbacks(callbacks)
|
||||
cb_params.train_dataset_element = None
|
||||
cb_params.network = self._network
|
||||
|
||||
# build callback list
|
||||
with _CallbackManager(callbacks) as list_callback:
|
||||
if not dataset_sink_mode:
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
elif context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
logger.warning("The pynative mode cannot support dataset sink mode currently."
|
||||
"So the training process will be performed with dataset not sink.")
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
else:
|
||||
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size)
|
||||
|
||||
@staticmethod
|
||||
def _transform_callbacks(callbacks):
|
||||
"""Transform callback to a list."""
|
||||
if callbacks is None:
|
||||
return []
|
||||
|
||||
if isinstance(callbacks, Iterable):
|
||||
return list(callbacks)
|
||||
|
||||
return [callbacks]
|
||||
|
||||
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
|
||||
"""
|
||||
Training process. The data would be passed to network through dataset channel.
|
||||
|
||||
Args:
|
||||
epoch (int): Total number of iterations on the data.
|
||||
train_dataset (Dataset): A training dataset iterator. If there is no
|
||||
loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be
|
||||
returned and passed to the network. Otherwise, a tuple (data, label) should
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
sink_size (int): Control the amount of data each sink. Default: -1.
|
||||
"""
|
||||
if sink_size == -1:
|
||||
epoch_num = epoch
|
||||
else:
|
||||
epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
|
||||
|
||||
iter_first_order = self._frequency - 1
|
||||
iter_second_order = 1
|
||||
train_dataset.__loop_size__ = iter_second_order
|
||||
dataset_helper, train_network = self._exec_preprocess(self._train_network,
|
||||
is_train=True,
|
||||
phase='train',
|
||||
dataset=train_dataset,
|
||||
dataset_sink_mode=True,
|
||||
sink_size=sink_size,
|
||||
epoch_num=epoch_num,
|
||||
iter_first_order=iter_first_order)
|
||||
self._train_network = train_network
|
||||
cb_params.train_network = self._train_network
|
||||
cb_params.cur_step_num = 0
|
||||
|
||||
run_context = RunContext(cb_params)
|
||||
list_callback.begin(run_context)
|
||||
|
||||
# used to stop training for early stop, such as stopAtTIme or stopATStep
|
||||
should_stop = False
|
||||
has_do_dataset_init = False
|
||||
switch_branch_one = True
|
||||
train_network_init_flag = True
|
||||
for i in range(epoch):
|
||||
cb_params.cur_epoch_num = i + 1
|
||||
list_callback.epoch_begin(run_context)
|
||||
|
||||
# for data sink dataset_helper only iter once, other wise iter epoch_size times.
|
||||
for inputs in dataset_helper:
|
||||
if _need_to_full():
|
||||
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
|
||||
cb_params.train_dataset_element = inputs
|
||||
list_callback.step_begin(run_context)
|
||||
if switch_branch_one:
|
||||
cb_params.cur_step_num += dataset_helper.sink_size()
|
||||
if train_network_init_flag:
|
||||
self._train_network.add_flags_recursive(thor=True)
|
||||
self._train_network.phase = 'train0'
|
||||
else:
|
||||
cb_params.cur_step_num += iter_first_order
|
||||
if train_network_init_flag:
|
||||
self._train_network.add_flags_recursive(thor=False)
|
||||
train_network_init_flag = False
|
||||
self._train_network.phase = 'train1'
|
||||
if not has_do_dataset_init:
|
||||
_exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
|
||||
has_do_dataset_init = True
|
||||
switch_branch_one = not switch_branch_one
|
||||
outputs = self._train_network(*inputs)
|
||||
cb_params.net_outputs = outputs
|
||||
list_callback.step_end(run_context)
|
||||
|
||||
list_callback.epoch_end(run_context)
|
||||
should_stop = should_stop or run_context.get_stop_requested()
|
||||
if should_stop:
|
||||
break
|
||||
dataset_helper.stop_send()
|
||||
|
||||
list_callback.end(run_context)
|
||||
|
||||
def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
|
||||
"""
|
||||
Training process. The data would be passed to network directly.
|
||||
|
||||
Args:
|
||||
epoch (int): Total number of iterations on the data.
|
||||
train_dataset (Dataset): A training dataset iterator. If there is no
|
||||
loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be
|
||||
returned and passed to the network. Otherwise, a tuple (data, label) should
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
"""
|
||||
dataset_helper, _ = self._exec_preprocess(self._train_network,
|
||||
is_train=True,
|
||||
phase='train',
|
||||
dataset=train_dataset,
|
||||
dataset_sink_mode=False)
|
||||
cb_params.cur_step_num = 0
|
||||
run_context = RunContext(cb_params)
|
||||
list_callback.begin(run_context)
|
||||
# used to stop training for early stop, such as stopAtTIme or stopATStep
|
||||
should_stop = False
|
||||
|
||||
for i in range(epoch):
|
||||
cb_params.cur_epoch_num = i + 1
|
||||
|
||||
list_callback.epoch_begin(run_context)
|
||||
|
||||
for next_element in dataset_helper:
|
||||
len_element = len(next_element)
|
||||
if self._loss_fn and len_element != 2:
|
||||
raise ValueError("when loss_fn is not None, train_dataset should"
|
||||
"return two elements, but got {}".format(len_element))
|
||||
cb_params.cur_step_num += 1
|
||||
|
||||
overflow = False
|
||||
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
|
||||
scaling_sens = self._get_scaling_sens()
|
||||
next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
|
||||
|
||||
cb_params.train_dataset_element = next_element
|
||||
list_callback.step_begin(run_context)
|
||||
outputs = self._train_network(*next_element)
|
||||
cb_params.net_outputs = outputs
|
||||
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
|
||||
_, overflow, _ = outputs
|
||||
overflow = np.all(overflow.asnumpy())
|
||||
self._loss_scale_manager.update_loss_scale(overflow)
|
||||
|
||||
list_callback.step_end(run_context)
|
||||
should_stop = should_stop or run_context.get_stop_requested()
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
train_dataset.reset()
|
||||
|
||||
list_callback.epoch_end(run_context)
|
||||
should_stop = should_stop or run_context.get_stop_requested()
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
list_callback.end(run_context)
|
||||
|
||||
def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
|
||||
"""
|
||||
Training API where the iteration is controlled by python front-end.
|
||||
|
||||
When setting pynative mode, the training process will be performed with dataset not sink.
|
||||
|
||||
Note:
|
||||
CPU is not supported when dataset_sink_mode is true.
|
||||
If dataset_sink_mode is True, epoch of training should be equal to the count of repeat
|
||||
operation in dataset processing. Otherwise, errors could occur since the amount of data
|
||||
is not the amount training requires.
|
||||
If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features
|
||||
of data will be transferred one by one. The limitation of data transmission per time is 256M.
|
||||
|
||||
Args:
|
||||
epoch (int): Total number of iterations on the data.
|
||||
train_dataset (Dataset): A training dataset iterator. If there is no
|
||||
loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be
|
||||
returned and passed to the network. Otherwise, a tuple (data, label) should
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
callbacks (list): List of callback object. Callbacks which should be executed while training. Default: None.
|
||||
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
|
||||
Configure pynative mode, the training process will be performed with
|
||||
dataset not sink.
|
||||
sink_size (int): Control the amount of data each sink.
|
||||
If sink_size=-1, sink the complete dataset each epoch.
|
||||
If sink_size>0, sink sink_size data each epoch.
|
||||
If dataset_sink_mode is False, set sink_size invalid. Default: -1.
|
||||
|
||||
Examples:
|
||||
>>> dataset = get_dataset()
|
||||
>>> net = Net()
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
||||
>>> loss_scale_manager = FixedLossScaleManager()
|
||||
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
|
||||
>>> model.train(2, dataset)
|
||||
"""
|
||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||
Validator.check_is_int(sink_size)
|
||||
if sink_size < -1 or sink_size == 0:
|
||||
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
|
||||
|
||||
_device_number_check(self._parallel_mode, self._device_number)
|
||||
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
|
||||
|
||||
self._train(epoch,
|
||||
train_dataset,
|
||||
callbacks=callbacks,
|
||||
dataset_sink_mode=dataset_sink_mode,
|
||||
sink_size=sink_size)
|
||||
|
||||
def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
|
||||
"""
|
||||
Evaluation. The data would be passed to network through dataset channel.
|
||||
|
||||
Args:
|
||||
valid_dataset (Dataset): Dataset to evaluate the model.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
|
||||
Returns:
|
||||
Dict, returns the loss value & metrics values for the model in test mode.
|
||||
"""
|
||||
run_context = RunContext(cb_params)
|
||||
|
||||
dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
|
||||
is_train=False,
|
||||
phase='eval',
|
||||
dataset=valid_dataset,
|
||||
dataset_sink_mode=True)
|
||||
self._eval_network = eval_network
|
||||
cb_params.eval_network = self._eval_network
|
||||
list_callback.begin(run_context)
|
||||
|
||||
for inputs in dataset_helper:
|
||||
cb_params.cur_step_num += 1
|
||||
list_callback.step_begin(run_context)
|
||||
|
||||
outputs = self._eval_network(*inputs)
|
||||
|
||||
cb_params.net_outputs = outputs
|
||||
list_callback.step_end(run_context)
|
||||
self._update_metrics(outputs)
|
||||
|
||||
metrics = self._get_metrics()
|
||||
cb_params.metrics = metrics
|
||||
list_callback.end(run_context)
|
||||
|
||||
return metrics
|
||||
|
||||
def _eval_process(self, valid_dataset, list_callback=None, cb_params=None):
|
||||
"""
|
||||
Evaluation. The data would be passed to network directly.
|
||||
|
||||
Args:
|
||||
valid_dataset (Dataset): Dataset to evaluate the model.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
|
||||
Returns:
|
||||
Dict, returns the loss value & metrics values for the model in test mode.
|
||||
"""
|
||||
run_context = RunContext(cb_params)
|
||||
list_callback.begin(run_context)
|
||||
|
||||
dataset_helper, _ = self._exec_preprocess(self._eval_network,
|
||||
is_train=False,
|
||||
phase='eval',
|
||||
dataset=valid_dataset,
|
||||
dataset_sink_mode=False)
|
||||
for next_element in dataset_helper:
|
||||
cb_params.cur_step_num += 1
|
||||
list_callback.step_begin(run_context)
|
||||
outputs = self._eval_network(*next_element)
|
||||
cb_params.net_outputs = outputs
|
||||
list_callback.step_end(run_context)
|
||||
self._update_metrics(outputs)
|
||||
|
||||
valid_dataset.reset()
|
||||
|
||||
metrics = self._get_metrics()
|
||||
cb_params.metrics = metrics
|
||||
list_callback.end(run_context)
|
||||
return metrics
|
||||
|
||||
def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True):
|
||||
"""
|
||||
Evaluation API where the iteration is controlled by python front-end.
|
||||
|
||||
Configure to pynative mode, the evaluation will be performed with dataset non-sink mode.
|
||||
|
||||
Note:
|
||||
CPU is not supported when dataset_sink_mode is true.
|
||||
If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features
|
||||
of data will be transferred one by one. The limitation of data transmission per time is 256M.
|
||||
|
||||
Args:
|
||||
valid_dataset (Dataset): Dataset to evaluate the model.
|
||||
callbacks (list): List of callback object. Callbacks which should be executed
|
||||
while training. Default: None.
|
||||
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
|
||||
|
||||
Returns:
|
||||
Dict, returns the loss value & metrics values for the model in test mode.
|
||||
|
||||
Examples:
|
||||
>>> dataset = get_dataset()
|
||||
>>> net = Net()
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
|
||||
>>> model.eval(dataset)
|
||||
"""
|
||||
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|
||||
_device_number_check(self._parallel_mode, self._device_number)
|
||||
if not self._metric_fns:
|
||||
raise ValueError("metric fn can not be None or empty.")
|
||||
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.eval_network = self._eval_network
|
||||
cb_params.valid_dataset = valid_dataset
|
||||
cb_params.batch_num = valid_dataset.get_dataset_size()
|
||||
cb_params.mode = "eval"
|
||||
cb_params.cur_step_num = 0
|
||||
cb_params.list_callback = self._transform_callbacks(callbacks)
|
||||
cb_params.network = self._network
|
||||
|
||||
self._eval_network.set_train(mode=False)
|
||||
self._eval_network.phase = 'eval'
|
||||
|
||||
self._clear_metrics()
|
||||
|
||||
with _CallbackManager(callbacks) as list_callback:
|
||||
if dataset_sink_mode:
|
||||
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
|
||||
return self._eval_process(valid_dataset, list_callback, cb_params)
|
||||
|
||||
def predict(self, *predict_data):
|
||||
"""
|
||||
Generates output predictions for the input samples.
|
||||
|
||||
Data could be single tensor, or list of tensor, tuple of tensor.
|
||||
|
||||
Note:
|
||||
Batch data should be put together in one tensor.
|
||||
|
||||
Args:
|
||||
predict_data (Tensor): Tensor of predict data. can be array, list or tuple.
|
||||
|
||||
Returns:
|
||||
Tensor, array(s) of predictions.
|
||||
|
||||
Examples:
|
||||
>>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
|
||||
>>> model = Model(Net())
|
||||
>>> model.predict(input_data)
|
||||
"""
|
||||
self._predict_network.set_train(False)
|
||||
check_input_data(*predict_data, data_class=Tensor)
|
||||
result = self._predict_network(*predict_data)
|
||||
|
||||
check_output_data(result)
|
||||
return result
|
||||
|
||||
|
||||
__all__ = ["Model"]
|
|
@ -1,355 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""momentum"""
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
|
||||
momentum_opt = C.MultitypeFuncGraph("momentum_opt")
|
||||
|
||||
|
||||
@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
|
||||
def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
|
||||
"""Apply momentum optimizer to the weight parameter using Tensor."""
|
||||
success = True
|
||||
success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
|
||||
return success
|
||||
|
||||
|
||||
op_add = P.AddN()
|
||||
apply_decay = C.MultitypeFuncGraph("apply_decay")
|
||||
|
||||
|
||||
@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
|
||||
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
|
||||
"""Get grad with weight_decay."""
|
||||
if if_apply:
|
||||
return op_add((weight * weight_decay, gradient))
|
||||
return gradient
|
||||
|
||||
|
||||
class THOR(Optimizer):
|
||||
"""THOR"""
|
||||
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, weight_decay=0.0,
|
||||
loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03,
|
||||
decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()):
|
||||
super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale)
|
||||
if isinstance(momentum, float) and momentum < 0.0:
|
||||
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
|
||||
self.momentum = Parameter(Tensor(momentum, mstype.float32))
|
||||
self.params = self.parameters
|
||||
self.moments = self.params.clone(prefix="moments", init='zeros')
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.ApplyMomentum()
|
||||
self.matrix_A = ParameterTuple(matrix_A)
|
||||
self.matrix_G = ParameterTuple(matrix_G)
|
||||
self.matmul = P.MatMul()
|
||||
self.transpose = P.Transpose()
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.mul = P.Mul()
|
||||
self.gather = P.Gather()
|
||||
self.matrix_A_inv = ()
|
||||
self.matrix_G_inv = ()
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.sqrt = P.Sqrt()
|
||||
self.assign = P.Assign()
|
||||
self.cast = P.Cast()
|
||||
self.thor = True
|
||||
self.weight_decay = weight_decay * loss_scale
|
||||
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
|
||||
self.expand = P.ExpandDims()
|
||||
self.square = P.Square()
|
||||
self.inv = P.Inv()
|
||||
self.batch_size = batch_size
|
||||
self.damping = damping
|
||||
self.one = Tensor(1, mstype.int32)
|
||||
self.cov_step = Parameter(initializer(0, [1], mstype.int32), requires_grad=False)
|
||||
|
||||
def construct(self, gradients):
|
||||
"""construct of THOR"""
|
||||
params = self.params
|
||||
moments = self.moments
|
||||
encoder_layers_num = 16
|
||||
if self.thor:
|
||||
new_grads = ()
|
||||
# process embedding layer
|
||||
for em_idx in range(3):
|
||||
g = gradients[em_idx]
|
||||
matrix_idx = em_idx
|
||||
temp_a_ori = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a_ori = F.depend(temp_a_ori, g)
|
||||
temp_g = F.depend(temp_g, g)
|
||||
temp_a = self.expand(temp_a_ori, 1)
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.mul(temp_a, g)
|
||||
g = self.matmul(g, temp_g)
|
||||
g = self.cast(g, mstype.float32)
|
||||
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a_ori)
|
||||
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
new_grads = new_grads + (g,)
|
||||
# process bert_embedding_postprocessor.layernorm
|
||||
grad_idx = 3
|
||||
beta_grad = gradients[grad_idx]
|
||||
gamma_grad = gradients[grad_idx + 1]
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
damping_step = self.gather(self.damping, self.cov_step, 0)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
self.cov_step = self.cov_step + self.one
|
||||
damping = self.sqrt(damping_step)
|
||||
beta = self.square(beta_grad)
|
||||
beta_cov = self.mul(beta, 1.0 / normalizer)
|
||||
beta_cov = beta_cov + damping
|
||||
beta_inv = self.inv(beta_cov)
|
||||
gamma = self.square(gamma_grad)
|
||||
gamma_cov = self.mul(gamma, 1.0 / normalizer)
|
||||
gamma_cov = gamma_cov + damping
|
||||
gamma_inv = self.inv(gamma_cov)
|
||||
beta = self.mul(beta_inv, beta_grad)
|
||||
gamma = self.mul(gamma_inv, gamma_grad)
|
||||
new_grads = new_grads + (beta, gamma)
|
||||
|
||||
for i in range(self.num_hidden_layers):
|
||||
encoder_begin_idx = encoder_layers_num * i + 5
|
||||
for j in range(0, encoder_layers_num, 2):
|
||||
grad_idx = encoder_begin_idx + j
|
||||
if j in (8, 14):
|
||||
# process layernorm layer
|
||||
beta_grad = gradients[grad_idx]
|
||||
gamma_grad = gradients[grad_idx + 1]
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
beta = self.square(beta_grad)
|
||||
beta_cov = self.mul(beta, 1.0 / normalizer)
|
||||
beta_cov = beta_cov + damping
|
||||
beta_inv = self.inv(beta_cov)
|
||||
gamma = self.square(gamma_grad)
|
||||
gamma_cov = self.mul(gamma, 1.0 / normalizer)
|
||||
gamma_cov = gamma_cov + damping
|
||||
gamma_inv = self.inv(gamma_cov)
|
||||
beta = self.mul(beta_inv, beta_grad)
|
||||
gamma = self.mul(gamma_inv, gamma_grad)
|
||||
new_grads = new_grads + (beta, gamma)
|
||||
else:
|
||||
g = gradients[grad_idx]
|
||||
offset_idx = 0
|
||||
if j in (0, 2, 4, 6):
|
||||
offset_idx = j // 2
|
||||
elif j in (10, 12):
|
||||
offset_idx = j // 2 - 1
|
||||
matrix_idx = 6 * i + offset_idx + 3
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = F.depend(temp_a, g)
|
||||
temp_g = F.depend(temp_g, g)
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
|
||||
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
new_grads = new_grads + (g,)
|
||||
new_grads = new_grads + (gradients[grad_idx + 1],)
|
||||
|
||||
# process pooler layer
|
||||
pooler_layer_idx = encoder_layers_num * self.num_hidden_layers + 5
|
||||
matrix_idx = self.num_hidden_layers * 6 + 3
|
||||
g = gradients[pooler_layer_idx]
|
||||
pooler_bias = gradients[pooler_layer_idx + 1]
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = F.depend(temp_a, g)
|
||||
temp_g = F.depend(temp_g, g)
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
|
||||
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
new_grads = new_grads + (g, pooler_bias)
|
||||
|
||||
# cls1 fully connect layer for masked language model(mlm)
|
||||
mlm_fc_idx = encoder_layers_num * self.num_hidden_layers + 8
|
||||
matrix_idx = self.num_hidden_layers * 6 + 4
|
||||
g = gradients[mlm_fc_idx]
|
||||
mlm_bias = gradients[mlm_fc_idx + 1]
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = F.depend(temp_a, g)
|
||||
temp_g = F.depend(temp_g, g)
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
# add bert.cls1.output_bias grad
|
||||
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
|
||||
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
new_grads = new_grads + (gradients[mlm_fc_idx - 1],)
|
||||
new_grads = new_grads + (g, mlm_bias)
|
||||
# add bert.cls1.layernorm grad
|
||||
begin_idx = mlm_fc_idx + 2
|
||||
end_idx = mlm_fc_idx + 4
|
||||
new_grads = new_grads + gradients[begin_idx: end_idx]
|
||||
|
||||
length = len(gradients)
|
||||
new_grads = new_grads + gradients[length - 2: length]
|
||||
gradients = new_grads
|
||||
else:
|
||||
new_grads = ()
|
||||
# process embedding layer
|
||||
for em_idx in range(3):
|
||||
g = gradients[em_idx]
|
||||
matrix_idx = em_idx
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = self.expand(temp_a, 1)
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.mul(temp_a, g)
|
||||
g = self.matmul(g, temp_g)
|
||||
g = self.cast(g, mstype.float32)
|
||||
new_grads = new_grads + (g,)
|
||||
# process bert_embedding_postprocessor.layernorm
|
||||
grad_idx = 3
|
||||
beta_grad = gradients[grad_idx]
|
||||
gamma_grad = gradients[grad_idx + 1]
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
damping_step = self.gather(self.damping, self.cov_step, 0)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
self.cov_step = self.cov_step + self.one
|
||||
damping = self.sqrt(damping_step)
|
||||
beta = self.square(beta_grad)
|
||||
beta_cov = self.mul(beta, 1.0 / normalizer)
|
||||
beta_cov = beta_cov + damping
|
||||
beta_inv = self.inv(beta_cov)
|
||||
gamma = self.square(gamma_grad)
|
||||
gamma_cov = self.mul(gamma, 1.0 / normalizer)
|
||||
gamma_cov = gamma_cov + damping
|
||||
gamma_inv = self.inv(gamma_cov)
|
||||
beta = self.mul(beta_inv, beta_grad)
|
||||
gamma = self.mul(gamma_inv, gamma_grad)
|
||||
new_grads = new_grads + (beta, gamma)
|
||||
|
||||
for i in range(self.num_hidden_layers):
|
||||
encoder_begin_idx = encoder_layers_num * i + 5
|
||||
for j in range(0, encoder_layers_num, 2):
|
||||
grad_idx = encoder_begin_idx + j
|
||||
if j in (8, 14):
|
||||
# process layernorm layer
|
||||
beta_grad = gradients[grad_idx]
|
||||
gamma_grad = gradients[grad_idx + 1]
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
beta = self.square(beta_grad)
|
||||
beta_cov = self.mul(beta, 1.0 / normalizer)
|
||||
beta_cov = beta_cov + damping
|
||||
beta_inv = self.inv(beta_cov)
|
||||
gamma = self.square(gamma_grad)
|
||||
gamma_cov = self.mul(gamma, 1.0 / normalizer)
|
||||
gamma_cov = gamma_cov + damping
|
||||
gamma_inv = self.inv(gamma_cov)
|
||||
beta = self.mul(beta_inv, beta_grad)
|
||||
gamma = self.mul(gamma_inv, gamma_grad)
|
||||
new_grads = new_grads + (beta, gamma)
|
||||
else:
|
||||
g = gradients[grad_idx]
|
||||
offset_idx = 0
|
||||
if j in (0, 2, 4, 6):
|
||||
offset_idx = j // 2
|
||||
elif j in (10, 12):
|
||||
offset_idx = j // 2 - 1
|
||||
matrix_idx = 6 * i + offset_idx + 3
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
new_grads = new_grads + (g,)
|
||||
new_grads = new_grads + (gradients[grad_idx + 1],)
|
||||
|
||||
# process pooler layer
|
||||
pooler_layer_idx = encoder_layers_num * self.num_hidden_layers + 5
|
||||
matrix_idx = self.num_hidden_layers * 6 + 3
|
||||
g = gradients[pooler_layer_idx]
|
||||
pooler_bias = gradients[pooler_layer_idx + 1]
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
new_grads = new_grads + (g, pooler_bias)
|
||||
|
||||
# cls1 fully connect layer for masked language model(mlm)
|
||||
mlm_fc_idx = encoder_layers_num * self.num_hidden_layers + 8
|
||||
matrix_idx = self.num_hidden_layers * 6 + 4
|
||||
g = gradients[mlm_fc_idx]
|
||||
mlm_bias = gradients[mlm_fc_idx + 1]
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
# add bert.cls1.output_bias grad
|
||||
new_grads = new_grads + (gradients[mlm_fc_idx - 1],)
|
||||
new_grads = new_grads + (g, mlm_bias)
|
||||
# add bert.cls1.layernorm grad
|
||||
begin_idx = mlm_fc_idx + 2
|
||||
end_idx = mlm_fc_idx + 4
|
||||
new_grads = new_grads + gradients[begin_idx: end_idx]
|
||||
|
||||
length = len(gradients)
|
||||
new_grads = new_grads + gradients[length - 2: length]
|
||||
gradients = new_grads
|
||||
|
||||
if self.weight_decay > 0:
|
||||
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
|
||||
params, gradients)
|
||||
gradients = self.scale_grad(gradients)
|
||||
lr = self.get_lr()
|
||||
success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
|
||||
return success
|
|
@ -1,362 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""momentum"""
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
|
||||
from .grad_reducer_thor import DistributedGradReducerThor
|
||||
|
||||
momentum_opt = C.MultitypeFuncGraph("momentum_opt")
|
||||
|
||||
|
||||
@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
|
||||
def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
|
||||
"""Apply momentum optimizer to the weight parameter using Tensor."""
|
||||
success = True
|
||||
success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
|
||||
return success
|
||||
|
||||
|
||||
op_add = P.AddN()
|
||||
apply_decay = C.MultitypeFuncGraph("apply_decay")
|
||||
|
||||
|
||||
@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
|
||||
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
|
||||
"""Get grad with weight_decay."""
|
||||
if if_apply:
|
||||
return op_add((weight * weight_decay, gradient))
|
||||
return gradient
|
||||
|
||||
|
||||
class THOR(Optimizer):
|
||||
"""THOR"""
|
||||
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, weight_decay=0.0,
|
||||
loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03,
|
||||
decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()):
|
||||
super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale)
|
||||
if isinstance(momentum, float) and momentum < 0.0:
|
||||
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
|
||||
self.momentum = Parameter(Tensor(momentum, mstype.float32))
|
||||
self.params = self.parameters
|
||||
self.moments = self.params.clone(prefix="moments", init='zeros')
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.ApplyMomentum()
|
||||
self.matrix_A = ParameterTuple(matrix_A)
|
||||
self.matrix_G = ParameterTuple(matrix_G)
|
||||
self.matmul = P.MatMul()
|
||||
self.transpose = P.Transpose()
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.mul = P.Mul()
|
||||
self.gather = P.Gather()
|
||||
self.matrix_A_inv = ()
|
||||
self.matrix_G_inv = ()
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.sqrt = P.Sqrt()
|
||||
self.assign = P.Assign()
|
||||
self.cast = P.Cast()
|
||||
self.thor = True
|
||||
self.weight_decay = weight_decay * loss_scale
|
||||
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
|
||||
self.expand = P.ExpandDims()
|
||||
self.square = P.Square()
|
||||
self.inv = P.Inv()
|
||||
self.batch_size = batch_size
|
||||
self.damping = damping
|
||||
self.one = Tensor(1, mstype.int32)
|
||||
self.cov_step = Parameter(initializer(0, [1], mstype.int32), requires_grad=False)
|
||||
mean = _get_gradients_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer_g = DistributedGradReducerThor(self.parameters, 3, mean, degree)
|
||||
|
||||
def construct(self, gradients):
|
||||
"""construct of THOR"""
|
||||
params = self.params
|
||||
moments = self.moments
|
||||
encoder_layers_num = 16
|
||||
if self.thor:
|
||||
new_grads = ()
|
||||
# process embedding layer
|
||||
for em_idx in range(3):
|
||||
g = gradients[em_idx]
|
||||
matrix_idx = em_idx
|
||||
temp_a_ori = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a_ori = F.depend(temp_a_ori, g)
|
||||
temp_g = F.depend(temp_g, g)
|
||||
temp_a = self.expand(temp_a_ori, 1)
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.mul(temp_a, g)
|
||||
g = self.matmul(g, temp_g)
|
||||
g = self.cast(g, mstype.float32)
|
||||
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a_ori)
|
||||
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
new_grads = new_grads + (g,)
|
||||
# process bert_embedding_postprocessor.layernorm
|
||||
grad_idx = 3
|
||||
beta_grad = gradients[grad_idx]
|
||||
gamma_grad = gradients[grad_idx + 1]
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
damping_step = self.gather(self.damping, self.cov_step, 0)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
self.cov_step = self.cov_step + self.one
|
||||
damping = self.sqrt(damping_step)
|
||||
beta = self.square(beta_grad)
|
||||
beta_cov = self.mul(beta, 1.0 / normalizer)
|
||||
beta_cov = beta_cov + damping
|
||||
beta_inv = self.inv(beta_cov)
|
||||
gamma = self.square(gamma_grad)
|
||||
gamma_cov = self.mul(gamma, 1.0 / normalizer)
|
||||
gamma_cov = gamma_cov + damping
|
||||
gamma_inv = self.inv(gamma_cov)
|
||||
beta = self.mul(beta_inv, beta_grad)
|
||||
gamma = self.mul(gamma_inv, gamma_grad)
|
||||
new_grads = new_grads + (beta, gamma)
|
||||
|
||||
for i in range(self.num_hidden_layers):
|
||||
encoder_begin_idx = encoder_layers_num * i + 5
|
||||
for j in range(0, encoder_layers_num, 2):
|
||||
grad_idx = encoder_begin_idx + j
|
||||
if j in (8, 14):
|
||||
# process layernorm layer
|
||||
beta_grad = gradients[grad_idx]
|
||||
gamma_grad = gradients[grad_idx + 1]
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
beta = self.square(beta_grad)
|
||||
beta_cov = self.mul(beta, 1.0 / normalizer)
|
||||
beta_cov = beta_cov + damping
|
||||
beta_inv = self.inv(beta_cov)
|
||||
gamma = self.square(gamma_grad)
|
||||
gamma_cov = self.mul(gamma, 1.0 / normalizer)
|
||||
gamma_cov = gamma_cov + damping
|
||||
gamma_inv = self.inv(gamma_cov)
|
||||
beta = self.mul(beta_inv, beta_grad)
|
||||
gamma = self.mul(gamma_inv, gamma_grad)
|
||||
new_grads = new_grads + (beta, gamma)
|
||||
else:
|
||||
g = gradients[grad_idx]
|
||||
offset_idx = 0
|
||||
if j in (0, 2, 4, 6):
|
||||
offset_idx = j // 2
|
||||
elif j in (10, 12):
|
||||
offset_idx = j // 2 - 1
|
||||
matrix_idx = 6 * i + offset_idx + 3
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = F.depend(temp_a, g)
|
||||
temp_g = F.depend(temp_g, g)
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
|
||||
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
new_grads = new_grads + (g,)
|
||||
new_grads = new_grads + (gradients[grad_idx + 1],)
|
||||
|
||||
# process pooler layer
|
||||
pooler_layer_idx = encoder_layers_num * self.num_hidden_layers + 5
|
||||
matrix_idx = self.num_hidden_layers * 6 + 3
|
||||
g = gradients[pooler_layer_idx]
|
||||
pooler_bias = gradients[pooler_layer_idx + 1]
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = F.depend(temp_a, g)
|
||||
temp_g = F.depend(temp_g, g)
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
|
||||
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
new_grads = new_grads + (g, pooler_bias)
|
||||
|
||||
# cls1 fully connect layer for masked language model(mlm)
|
||||
mlm_fc_idx = encoder_layers_num * self.num_hidden_layers + 8
|
||||
matrix_idx = self.num_hidden_layers * 6 + 4
|
||||
g = gradients[mlm_fc_idx]
|
||||
mlm_bias = gradients[mlm_fc_idx + 1]
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = F.depend(temp_a, g)
|
||||
temp_g = F.depend(temp_g, g)
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
# add bert.cls1.output_bias grad
|
||||
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
|
||||
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
new_grads = new_grads + (gradients[mlm_fc_idx - 1],)
|
||||
new_grads = new_grads + (g, mlm_bias)
|
||||
# add bert.cls1.layernorm grad
|
||||
begin_idx = mlm_fc_idx + 2
|
||||
end_idx = mlm_fc_idx + 4
|
||||
new_grads = new_grads + gradients[begin_idx: end_idx]
|
||||
|
||||
length = len(gradients)
|
||||
new_grads = new_grads + gradients[length - 2: length]
|
||||
gradients = new_grads
|
||||
gradients = self.grad_reducer_g(gradients)
|
||||
else:
|
||||
new_grads = ()
|
||||
# process embedding layer
|
||||
for em_idx in range(3):
|
||||
g = gradients[em_idx]
|
||||
matrix_idx = em_idx
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = self.expand(temp_a, 1)
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.mul(temp_a, g)
|
||||
g = self.matmul(g, temp_g)
|
||||
g = self.cast(g, mstype.float32)
|
||||
new_grads = new_grads + (g,)
|
||||
# process bert_embedding_postprocessor.layernorm
|
||||
grad_idx = 3
|
||||
beta_grad = gradients[grad_idx]
|
||||
gamma_grad = gradients[grad_idx + 1]
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
damping_step = self.gather(self.damping, self.cov_step, 0)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
self.cov_step = self.cov_step + self.one
|
||||
damping = self.sqrt(damping_step)
|
||||
beta = self.square(beta_grad)
|
||||
beta_cov = self.mul(beta, 1.0 / normalizer)
|
||||
beta_cov = beta_cov + damping
|
||||
beta_inv = self.inv(beta_cov)
|
||||
gamma = self.square(gamma_grad)
|
||||
gamma_cov = self.mul(gamma, 1.0 / normalizer)
|
||||
gamma_cov = gamma_cov + damping
|
||||
gamma_inv = self.inv(gamma_cov)
|
||||
beta = self.mul(beta_inv, beta_grad)
|
||||
gamma = self.mul(gamma_inv, gamma_grad)
|
||||
new_grads = new_grads + (beta, gamma)
|
||||
|
||||
for i in range(self.num_hidden_layers):
|
||||
encoder_begin_idx = encoder_layers_num * i + 5
|
||||
for j in range(0, encoder_layers_num, 2):
|
||||
grad_idx = encoder_begin_idx + j
|
||||
if j in (8, 14):
|
||||
# process layernorm layer
|
||||
beta_grad = gradients[grad_idx]
|
||||
gamma_grad = gradients[grad_idx + 1]
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
beta = self.square(beta_grad)
|
||||
beta_cov = self.mul(beta, 1.0 / normalizer)
|
||||
beta_cov = beta_cov + damping
|
||||
beta_inv = self.inv(beta_cov)
|
||||
gamma = self.square(gamma_grad)
|
||||
gamma_cov = self.mul(gamma, 1.0 / normalizer)
|
||||
gamma_cov = gamma_cov + damping
|
||||
gamma_inv = self.inv(gamma_cov)
|
||||
beta = self.mul(beta_inv, beta_grad)
|
||||
gamma = self.mul(gamma_inv, gamma_grad)
|
||||
new_grads = new_grads + (beta, gamma)
|
||||
else:
|
||||
g = gradients[grad_idx]
|
||||
offset_idx = 0
|
||||
if j in (0, 2, 4, 6):
|
||||
offset_idx = j // 2
|
||||
elif j in (10, 12):
|
||||
offset_idx = j // 2 - 1
|
||||
matrix_idx = 6 * i + offset_idx + 3
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
new_grads = new_grads + (g,)
|
||||
new_grads = new_grads + (gradients[grad_idx + 1],)
|
||||
|
||||
# process pooler layer
|
||||
pooler_layer_idx = encoder_layers_num * self.num_hidden_layers + 5
|
||||
matrix_idx = self.num_hidden_layers * 6 + 3
|
||||
g = gradients[pooler_layer_idx]
|
||||
pooler_bias = gradients[pooler_layer_idx + 1]
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
new_grads = new_grads + (g, pooler_bias)
|
||||
|
||||
# cls1 fully connect layer for masked language model(mlm)
|
||||
mlm_fc_idx = encoder_layers_num * self.num_hidden_layers + 8
|
||||
matrix_idx = self.num_hidden_layers * 6 + 4
|
||||
g = gradients[mlm_fc_idx]
|
||||
mlm_bias = gradients[mlm_fc_idx + 1]
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
# add bert.cls1.output_bias grad
|
||||
new_grads = new_grads + (gradients[mlm_fc_idx - 1],)
|
||||
new_grads = new_grads + (g, mlm_bias)
|
||||
# add bert.cls1.layernorm grad
|
||||
begin_idx = mlm_fc_idx + 2
|
||||
end_idx = mlm_fc_idx + 4
|
||||
new_grads = new_grads + gradients[begin_idx: end_idx]
|
||||
|
||||
length = len(gradients)
|
||||
new_grads = new_grads + gradients[length - 2: length]
|
||||
gradients = new_grads
|
||||
gradients = self.grad_reducer_g(gradients)
|
||||
|
||||
if self.weight_decay > 0:
|
||||
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
|
||||
params, gradients)
|
||||
gradients = self.scale_grad(gradients)
|
||||
lr = self.get_lr()
|
||||
success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
|
||||
return success
|
|
@ -1,273 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""thor_layer"""
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common.initializer import TruncatedNormal, initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
class Embedding_Thor(Cell):
|
||||
"""
|
||||
A embeddings lookup table with a fixed dictionary and size.
|
||||
|
||||
Args:
|
||||
vocab_size (int): Size of the dictionary of embeddings.
|
||||
embedding_size (int): The size of each embedding vector.
|
||||
embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
|
||||
each embedding vector.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
"""
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
embedding_size,
|
||||
embedding_shape,
|
||||
use_one_hot_embeddings=False,
|
||||
initializer_range=0.02,
|
||||
batch_size=12,
|
||||
damping=0.03,
|
||||
loss_scale=1,
|
||||
frequency=100,
|
||||
):
|
||||
super(Embedding_Thor, self).__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
self.embedding_table = Parameter(initializer
|
||||
(TruncatedNormal(initializer_range),
|
||||
[vocab_size, embedding_size]))
|
||||
self.thor = True
|
||||
self.expand = P.ExpandDims()
|
||||
self.shape_flat = (-1,)
|
||||
self.gather = P.Gather()
|
||||
self.one_hot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.array_mul = P.MatMul()
|
||||
self.reshape = P.Reshape()
|
||||
self.em_shape = tuple(embedding_shape)
|
||||
self.shape = P.Shape()
|
||||
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
|
||||
|
||||
self.matrix_A_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float16)), requires_grad=False)
|
||||
self.matrix_G_inv = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float16)),
|
||||
requires_grad=False)
|
||||
self.fake_G = Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float16))
|
||||
self.dampingA = Tensor(np.ones([vocab_size]).astype(np.float32))
|
||||
self.dampingG = Tensor(np.identity(embedding_size), mstype.float32)
|
||||
self.cov_step = Parameter(initializer(0, [1], mstype.int32), requires_grad=False)
|
||||
self.freq = Tensor(frequency, mstype.int32)
|
||||
self.axis = 0
|
||||
self.damping = damping
|
||||
self.gather = P.Gather()
|
||||
self.sqrt = P.Sqrt()
|
||||
self.mul = P.Mul()
|
||||
self.cast = P.Cast()
|
||||
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
|
||||
self.vector_matmul = P.CusBatchMatMul()
|
||||
self.cholesky = P.CusCholeskyTrsm()
|
||||
self.matrix_combine = P.CusMatrixCombine()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.inv = P.Inv()
|
||||
self.getG = P.InsertGradientOf(self.save_gradient)
|
||||
self.batch_size = batch_size
|
||||
|
||||
def save_gradient(self, dout):
|
||||
"""save_gradient"""
|
||||
bs = self.batch_size
|
||||
bs = self.cast(bs, mstype.float32)
|
||||
out = dout
|
||||
dout = self.mul(dout, self.loss_scale)
|
||||
dout = self.mul(dout, bs)
|
||||
shape = self.shape(dout)
|
||||
normalizer = self.cast(shape[0], mstype.float32)
|
||||
matrix_G = self.cube_matmul(dout, dout)
|
||||
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
|
||||
damping_step = self.gather(self.damping, self.cov_step, 0)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
self.cov_step = self.cov_step + self.freq
|
||||
damping = self.sqrt(damping_step)
|
||||
dampingG = self.cast(self.dampingG, mstype.float32)
|
||||
matrix_G = matrix_G + damping * dampingG
|
||||
matrix_G_inv = self.cholesky(matrix_G)
|
||||
matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)
|
||||
matrix_G_inv = self.matrix_combine(matrix_G_inv)
|
||||
matrix_G_inv = self.cast(matrix_G_inv, mstype.float16)
|
||||
self.matrix_G_inv = matrix_G_inv
|
||||
return out
|
||||
|
||||
def construct(self, input_ids):
|
||||
"""construct of Embedding_Thor"""
|
||||
flat_ids = self.reshape(input_ids, self.shape_flat)
|
||||
if self.use_one_hot_embeddings:
|
||||
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
|
||||
output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
|
||||
else:
|
||||
if self.thor:
|
||||
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
|
||||
matrix_A = self.reduce_sum(one_hot_ids, 0)
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
matrix_A = self.mul(matrix_A, 1.0 / normalizer)
|
||||
damping_step = self.gather(self.damping, self.cov_step, self.axis)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
damping = self.sqrt(damping_step)
|
||||
dampingA = self.cast(self.dampingA, mstype.float32)
|
||||
matrix_A = matrix_A + damping * dampingA
|
||||
matrix_A_inv = self.inv(matrix_A)
|
||||
matrix_A_inv = self.cast(matrix_A_inv, mstype.float16)
|
||||
self.matrix_A_inv = matrix_A_inv
|
||||
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
|
||||
output_for_reshape = self.getG(output_for_reshape)
|
||||
else:
|
||||
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
|
||||
|
||||
output = self.reshape(output_for_reshape, self.em_shape)
|
||||
return output, self.embedding_table
|
||||
|
||||
class Dense_Thor(Cell):
|
||||
"""Dense_Thor"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
damping=0.03,
|
||||
loss_scale=1,
|
||||
frequency=100,
|
||||
has_bias=False,
|
||||
activation=None,
|
||||
batch_size=12):
|
||||
super(Dense_Thor, self).__init__()
|
||||
self.in_channels = Validator.check_positive_int(in_channels)
|
||||
self.out_channels = Validator.check_positive_int(out_channels)
|
||||
self.has_bias = Validator.check_bool(has_bias)
|
||||
self.thor = True
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.ndim != 2 or weight_init.shape()[0] != out_channels or \
|
||||
weight_init.shape()[1] != in_channels:
|
||||
raise ValueError("weight_init shape error")
|
||||
|
||||
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]))
|
||||
|
||||
if self.has_bias:
|
||||
if isinstance(bias_init, Tensor):
|
||||
if bias_init.ndim != 1 or bias_init.shape()[0] != out_channels:
|
||||
raise ValueError("bias_init shape error")
|
||||
|
||||
self.bias = Parameter(initializer(bias_init, [out_channels]))
|
||||
|
||||
self.matmul = P.MatMul(transpose_b=True)
|
||||
self.bias_add = P.BiasAdd()
|
||||
|
||||
self.activation = get_activation(activation)
|
||||
self.activation_flag = self.activation is not None
|
||||
self.matrix_A_inv = Parameter(Tensor(np.zeros([in_channels, in_channels]).astype(np.float16)),
|
||||
requires_grad=False)
|
||||
self.matrix_G_inv = Parameter(Tensor(np.zeros([out_channels, out_channels]).astype(np.float16)),
|
||||
requires_grad=False)
|
||||
self.fake_G = Tensor(np.zeros([out_channels, out_channels]).astype(np.float16))
|
||||
|
||||
self.matmul = P.MatMul(transpose_b=True)
|
||||
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
|
||||
self.matrix_combine = P.CusMatrixCombine()
|
||||
self.cholesky = P.CusCholeskyTrsm()
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.cov_step = Parameter(initializer(0, [1], mstype.int32), requires_grad=False)
|
||||
self.mul = P.Mul()
|
||||
self.cast = P.Cast()
|
||||
self.damping = damping
|
||||
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
|
||||
self.vector_matmul = P.CusBatchMatMul()
|
||||
self.gather = P.Gather()
|
||||
self.assignadd = P.AssignAdd()
|
||||
self.freq = Tensor(frequency, mstype.int32)
|
||||
self.axis = 0
|
||||
self.abs = P.Abs()
|
||||
self.reduce_max = P.ReduceMax(keep_dims=False)
|
||||
self.log = P.Log()
|
||||
self.exp = P.Exp()
|
||||
self.dampingA = Tensor(np.identity(in_channels), mstype.float32)
|
||||
self.dampingG = Tensor(np.identity(out_channels), mstype.float32)
|
||||
self.sqrt = P.Sqrt()
|
||||
self.getG = P.InsertGradientOf(self.save_gradient)
|
||||
self.batch_size = batch_size
|
||||
|
||||
def save_gradient(self, dout):
|
||||
"""save_gradient"""
|
||||
bs = self.cast(self.batch_size, mstype.float32)
|
||||
out = dout
|
||||
dout = self.mul(dout, self.loss_scale)
|
||||
dout = self.mul(dout, bs)
|
||||
shape = self.shape(dout)
|
||||
normalizer = self.cast(shape[0], mstype.float32)
|
||||
matrix_G = self.cube_matmul(dout, dout)
|
||||
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
|
||||
damping_step = self.gather(self.damping, self.cov_step, 0)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
self.cov_step = self.cov_step + self.freq
|
||||
damping = self.sqrt(damping_step)
|
||||
dampingG = self.cast(self.dampingG, mstype.float32)
|
||||
matrix_G = matrix_G + damping * dampingG
|
||||
matrix_G_inv = self.cholesky(matrix_G)
|
||||
matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)
|
||||
matrix_G_inv = self.matrix_combine(matrix_G_inv)
|
||||
matrix_G_inv = self.cast(matrix_G_inv, mstype.float16)
|
||||
self.matrix_G_inv = matrix_G_inv
|
||||
return out
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
if self.thor:
|
||||
inputs = self.cube_matmul(x, x)
|
||||
shape = self.shape(x)
|
||||
normalizer = self.cast(shape[0], mstype.float32)
|
||||
matrix_A = self.mul(inputs, 1.0 / normalizer)
|
||||
damping_step = self.gather(self.damping, self.cov_step, self.axis)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
damping = self.sqrt(damping_step)
|
||||
dampingA = self.cast(self.dampingA, mstype.float32)
|
||||
matrix_A = matrix_A + damping * dampingA
|
||||
matrix_A_inv = self.cholesky(matrix_A)
|
||||
matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv)
|
||||
matrix_A_inv = self.matrix_combine(matrix_A_inv)
|
||||
matrix_A_inv = self.cast(matrix_A_inv, mstype.float16)
|
||||
self.matrix_A_inv = matrix_A_inv
|
||||
output = self.matmul(x, self.weight)
|
||||
output = self.getG(output)
|
||||
else:
|
||||
output = self.matmul(x, self.weight)
|
||||
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
if self.activation_flag:
|
||||
return self.activation(output)
|
||||
return output
|
||||
|
||||
def extend_repr(self):
|
||||
"""extend_repr"""
|
||||
s = 'in_channels={}, out_channels={}'.format(self.in_channels, self.out_channels)
|
||||
if self.has_bias:
|
||||
s += ', bias={}'.format(self.bias)
|
||||
if self.activation_flag:
|
||||
s += ', activation={}'.format(self.activation)
|
||||
return s
|
|
@ -18,24 +18,22 @@ Functional Cells used in Bert finetune and evaluation.
|
|||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import math
|
||||
import collections
|
||||
import numpy as np
|
||||
from src.config import cfg
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
|
||||
from mindspore import log as logger
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
|
||||
|
||||
|
||||
class CrossEntropyCalculation(nn.Cell):
|
||||
"""
|
||||
Cross Entropy loss
|
||||
"""
|
||||
|
||||
def __init__(self, is_training=True):
|
||||
super(CrossEntropyCalculation, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
|
@ -85,7 +83,6 @@ def make_directory(path: str):
|
|||
raise TypeError("No write permission on the directory.")
|
||||
return real_path
|
||||
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
|
@ -95,28 +92,25 @@ class LossCallBack(Callback):
|
|||
Args:
|
||||
per_print_times (int): Print loss every times. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self, per_print_times=1):
|
||||
def __init__(self, dataset_size=-1):
|
||||
super(LossCallBack, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0")
|
||||
self._per_print_times = per_print_times
|
||||
self.step_start_time = time.time()
|
||||
|
||||
def step_begin(self, run_context):
|
||||
self.step_start_time = time.time()
|
||||
|
||||
self._dataset_size = dataset_size
|
||||
def step_end(self, run_context):
|
||||
"""
|
||||
Print loss after each step
|
||||
"""
|
||||
cb_params = run_context.original_args()
|
||||
step_time_span = time.time() - self.step_start_time
|
||||
total_time_span = step_time_span
|
||||
cur_step_num = cb_params.cur_step_num
|
||||
if cur_step_num % cfg.Thor.frequency == 0:
|
||||
step_time_span = step_time_span / (cfg.Thor.frequency - 1)
|
||||
print("epoch: {}, step: {}, outputs are {}, total_time_span is {}, step_time_span is {}".format(
|
||||
cb_params.cur_epoch_num, cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs), total_time_span, step_time_span))
|
||||
|
||||
if self._dataset_size > 0:
|
||||
percent, epoch_num = math.modf(cb_params.cur_step_num / self._dataset_size)
|
||||
if percent == 0:
|
||||
percent = 1
|
||||
epoch_num -= 1
|
||||
print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
|
||||
.format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)),
|
||||
flush=True)
|
||||
else:
|
||||
print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs)), flush=True)
|
||||
|
||||
def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix):
|
||||
"""
|
||||
|
@ -135,7 +129,7 @@ def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, pre
|
|||
load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
|
||||
elif index not in (0, -1):
|
||||
name_split = name_ext[-2].split('_')
|
||||
if (steps_per_epoch != int(name_split[len(name_split) - 1])) \
|
||||
if (steps_per_epoch != int(name_split[len(name_split)-1])) \
|
||||
or (epoch_num != int(filename[pre_len + index + 1:pre_len + index + 2])):
|
||||
continue
|
||||
num = filename[pre_len + 1:pre_len + index]
|
||||
|
@ -149,10 +143,12 @@ class BertLearningRate(LearningRateSchedule):
|
|||
"""
|
||||
Warmup-decay learning rate for Bert network.
|
||||
"""
|
||||
|
||||
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
|
||||
super(BertLearningRate, self).__init__()
|
||||
self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
|
||||
self.warmup_flag = False
|
||||
if warmup_steps > 0:
|
||||
self.warmup_flag = True
|
||||
self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
|
||||
self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
|
||||
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
|
||||
|
||||
|
@ -161,8 +157,76 @@ class BertLearningRate(LearningRateSchedule):
|
|||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, global_step):
|
||||
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
|
||||
warmup_lr = self.warmup_lr(global_step)
|
||||
decay_lr = self.decay_lr(global_step)
|
||||
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
|
||||
if self.warmup_flag:
|
||||
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
|
||||
warmup_lr = self.warmup_lr(global_step)
|
||||
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
|
||||
else:
|
||||
lr = decay_lr
|
||||
return lr
|
||||
|
||||
|
||||
def convert_labels_to_index(label_list):
|
||||
"""
|
||||
Convert label_list to indices for NER task.
|
||||
"""
|
||||
label2id = collections.OrderedDict()
|
||||
label2id["O"] = 0
|
||||
prefix = ["S_", "B_", "M_", "E_"]
|
||||
index = 0
|
||||
for label in label_list:
|
||||
for pre in prefix:
|
||||
index += 1
|
||||
sub_label = pre + label
|
||||
label2id[sub_label] = index
|
||||
return label2id
|
||||
|
||||
def _get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps, poly_power):
|
||||
"""
|
||||
generate learning rate array
|
||||
|
||||
Args:
|
||||
global_step(int): current step
|
||||
lr_init(float): init learning rate
|
||||
lr_end(float): end learning rate
|
||||
lr_max(float): max learning rate
|
||||
warmup_steps(int): number of warmup epochs
|
||||
total_steps(int): total epoch of training
|
||||
poly_power(int): poly learning rate power
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array
|
||||
"""
|
||||
lr_each_step = []
|
||||
if warmup_steps != 0:
|
||||
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
|
||||
else:
|
||||
inc_each_step = 0
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = float(lr_init) + inc_each_step * float(i)
|
||||
else:
|
||||
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
|
||||
lr = float(lr_max - lr_end) * (base ** poly_power)
|
||||
lr = lr + lr_end
|
||||
if lr < 0.0:
|
||||
lr = 0.0
|
||||
lr_each_step.append(lr)
|
||||
|
||||
learning_rate = np.array(lr_each_step).astype(np.float32)
|
||||
current_step = global_step
|
||||
learning_rate = learning_rate[current_step:]
|
||||
return learning_rate
|
||||
|
||||
|
||||
def get_bert_thor_lr(lr_max=0.0034, lr_min=3.244e-05, lr_power=1.0, lr_total_steps=30000):
|
||||
learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=lr_min, lr_max=lr_max, warmup_steps=0,
|
||||
total_steps=lr_total_steps, poly_power=lr_power)
|
||||
return Tensor(learning_rate)
|
||||
|
||||
|
||||
def get_bert_thor_damping(damping_max=5e-2, damping_min=1e-6, damping_power=1.0, damping_total_steps=30000):
|
||||
damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=damping_min, lr_max=damping_max, warmup_steps=0,
|
||||
total_steps=damping_total_steps, poly_power=damping_power)
|
||||
return Tensor(damping)
|
||||
|
|
Loading…
Reference in New Issue