forked from OSSInnovation/mindspore
for bert_thor 1st
for bert_thor 2nd for bert_thor 3rd for bert_thor 4th
This commit is contained in:
parent
16079e6356
commit
1c8eb5910b
|
@ -14,10 +14,11 @@
|
|||
# ============================================================================
|
||||
"""batch_matmul_impl"""
|
||||
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
from te import tik
|
||||
from topi.cce import util
|
||||
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
cus_batchmatmul_op_info = TBERegOp("CusBatchMatMul") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
|
@ -114,7 +115,8 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr
|
|||
((1, 64, 64), (1, 64, 64), "float32", False, True),
|
||||
((1, 128, 128), (1, 128, 128), "float32", False, True),
|
||||
((4, 128, 128), (4, 128, 128), "float32", False, True),
|
||||
((2, 128, 128), (2, 128, 128), "float32", False, True)]
|
||||
((2, 128, 128), (2, 128, 128), "float32", False, True),
|
||||
((32, 128, 128), (32, 128, 128), 'float32', False, True)]
|
||||
if input_shape not in support_shape:
|
||||
raise RuntimeError("input_shape %s is not supported" % str(input_shape))
|
||||
|
||||
|
@ -232,7 +234,8 @@ def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=Tr
|
|||
((2, 128, 128), (2, 128, 128), "float32", False, True),
|
||||
((4, 128, 128), (4, 128, 128), "float32", False, True),
|
||||
((8, 128, 128), (8, 128, 128), "float32", False, True),
|
||||
((16, 128, 128), (16, 128, 128), "float32", False, True)
|
||||
((16, 128, 128), (16, 128, 128), "float32", False, True),
|
||||
((32, 128, 128), (32, 128, 128), 'float32', False, True)
|
||||
]
|
||||
if input_shape in input_shape_list:
|
||||
block_num = 32
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,93 @@
|
|||
# BERT Example
|
||||
## Description
|
||||
This is an example of training bert by second-order optimizer THOR. THOR is a novel approximate seond-order optimization method in MindSpore.
|
||||
|
||||
## Requirements
|
||||
- Install [MindSpore](https://www.mindspore.cn/install/en).
|
||||
- Download the zhwiki dataset for pre-training. Extract and clean text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format and move the files to a specified path.
|
||||
- Download dataset for fine-tuning and evaluation such as CLUENER, TNEWS, SQuAD v1.1, etc.
|
||||
> Notes:
|
||||
If you are running a fine-tuning or evaluation task, prepare a checkpoint from pre-train.
|
||||
|
||||
## Running the Example
|
||||
### Pre-Training
|
||||
- Set options in `config.py`, including lossscale, optimizer and network. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about dataset and the json schema file.
|
||||
|
||||
- Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base and BERT-NEZHA model.
|
||||
|
||||
``` bash
|
||||
sh scripts/run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR
|
||||
```
|
||||
- Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base and BERT-NEZHA model.
|
||||
|
||||
``` bash
|
||||
sh scripts/run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH
|
||||
```
|
||||
|
||||
## Usage
|
||||
### Pre-Training
|
||||
```
|
||||
usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_num N] [--device_id N]
|
||||
[--enable_save_ckpt ENABLE_SAVE_CKPT]
|
||||
[--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
|
||||
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--checkpoint_path CHECKPOINT_PATH]
|
||||
[--save_checkpoint_steps N] [--save_checkpoint_num N]
|
||||
[--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR]
|
||||
|
||||
options:
|
||||
--distribute pre_training by serveral devices: "true"(training by more than 1 device) | "false", default is "false"
|
||||
--epoch_size epoch size: N, default is 1
|
||||
--device_num number of used devices: N, default is 1
|
||||
--device_id device id: N, default is 0
|
||||
--enable_save_ckpt enable save checkpoint: "true" | "false", default is "true"
|
||||
--enable_lossscale enable lossscale: "true" | "false", default is "true"
|
||||
--do_shuffle enable shuffle: "true" | "false", default is "true"
|
||||
--enable_data_sink enable data sink: "true" | "false", default is "true"
|
||||
--data_sink_steps set data sink steps: N, default is 1
|
||||
--checkpoint_path path to save checkpoint files: PATH, default is ""
|
||||
--save_checkpoint_steps steps for saving checkpoint files: N, default is 1000
|
||||
--save_checkpoint_num number for saving checkpoint files: N, default is 1
|
||||
--data_dir path to dataset directory: PATH, default is ""
|
||||
--schema_dir path to schema.json file, PATH, default is ""
|
||||
```
|
||||
## Options and Parameters
|
||||
It contains of parameters of BERT model and options for training, which is set in file `config.py`, `bert_net_config.py` and `evaluation_config.py` respectively.
|
||||
### Options:
|
||||
```
|
||||
config.py:
|
||||
bert_network version of BERT model: base | nezha, default is base
|
||||
optimizer optimizer used in the network: AdamWerigtDecayDynamicLR | Lamb | Momentum | Thor, default is "Thor"
|
||||
|
||||
```
|
||||
|
||||
### Parameters:
|
||||
```
|
||||
Parameters for dataset and network (Pre-Training/Evaluation):
|
||||
batch_size batch size of input dataset: N, default is 8
|
||||
seq_length length of input sequence: N, default is 128
|
||||
vocab_size size of each embedding vector: N, must be consistant with the dataset you use. Default is 21136
|
||||
hidden_size size of bert encoder layers: N, default is 768
|
||||
num_hidden_layers number of hidden layers: N, default is 12
|
||||
num_attention_heads number of attention heads: N, default is 12
|
||||
intermediate_size size of intermediate layer: N, default is 3072
|
||||
hidden_act activation function used: ACTIVATION, default is "gelu"
|
||||
hidden_dropout_prob dropout probability for BertOutput: Q, default is 0.1
|
||||
attention_probs_dropout_prob dropout probability for BertAttention: Q, default is 0.1
|
||||
max_position_embeddings maximum length of sequences: N, default is 512
|
||||
type_vocab_size size of token type vocab: N, default is 16
|
||||
initializer_range initialization value of TruncatedNormal: Q, default is 0.02
|
||||
use_relative_positions use relative positions or not: True | False, default is False
|
||||
input_mask_from_dataset use the input mask loaded form dataset or not: True | False, default is True
|
||||
token_type_ids_from_dataset use the token type ids loaded from dataset or not: True | False, default is True
|
||||
dtype data type of input: mstype.float16 | mstype.float32, default is mstype.float32
|
||||
compute_type compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16
|
||||
|
||||
Parameters for optimizer:
|
||||
Thor:
|
||||
momentum momentum for the moving average: Q
|
||||
weight_decay weight decay: Q
|
||||
loss_scale loss scale: N
|
||||
frequency the step interval to update second-order information matrix: N, default is 10
|
||||
batch_size batch size of input dataset: N, default is 8
|
||||
```
|
||||
|
|
@ -0,0 +1,164 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
Bert evaluation script.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from src import BertModel, GetMaskedLMOutput
|
||||
from src.evaluation_config import cfg, bert_net_cfg
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.metrics import Metric
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
|
||||
class myMetric(Metric):
|
||||
'''
|
||||
Self-defined Metric as a callback.
|
||||
'''
|
||||
|
||||
def __init__(self):
|
||||
super(myMetric, self).__init__()
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
self.total_num = 0
|
||||
self.acc_num = 0
|
||||
|
||||
def update(self, *inputs):
|
||||
total_num = self._convert_data(inputs[0])
|
||||
acc_num = self._convert_data(inputs[1])
|
||||
self.total_num = total_num
|
||||
self.acc_num = acc_num
|
||||
|
||||
def eval(self):
|
||||
return self.acc_num / self.total_num
|
||||
|
||||
|
||||
class GetLogProbs(nn.Cell):
|
||||
'''
|
||||
Get MaskedLM prediction scores
|
||||
'''
|
||||
|
||||
def __init__(self, config):
|
||||
super(GetLogProbs, self).__init__()
|
||||
self.bert = BertModel(config, False)
|
||||
self.cls1 = GetMaskedLMOutput(config)
|
||||
|
||||
def construct(self, input_ids, input_mask, token_type_id, masked_pos):
|
||||
sequence_output, _, embedding_table = self.bert(input_ids, token_type_id, input_mask)
|
||||
prediction_scores = self.cls1(sequence_output, embedding_table, masked_pos)
|
||||
return prediction_scores
|
||||
|
||||
|
||||
class BertPretrainEva(nn.Cell):
|
||||
'''
|
||||
Evaluate MaskedLM prediction scores
|
||||
'''
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertPretrainEva, self).__init__()
|
||||
self.bert = GetLogProbs(config)
|
||||
self.argmax = P.Argmax(axis=-1, output_type=mstype.int32)
|
||||
self.equal = P.Equal()
|
||||
self.mean = P.ReduceMean()
|
||||
self.sum = P.ReduceSum()
|
||||
self.total = Parameter(Tensor([0], mstype.float32), name='total')
|
||||
self.acc = Parameter(Tensor([0], mstype.float32), name='acc')
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = P.Shape()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, input_ids, input_mask, token_type_id, masked_pos, masked_ids, masked_weights, nsp_label):
|
||||
"""construct of BertPretrainEva"""
|
||||
bs, _ = self.shape(input_ids)
|
||||
probs = self.bert(input_ids, input_mask, token_type_id, masked_pos)
|
||||
index = self.argmax(probs)
|
||||
index = self.reshape(index, (bs, -1))
|
||||
eval_acc = self.equal(index, masked_ids)
|
||||
eval_acc1 = self.cast(eval_acc, mstype.float32)
|
||||
real_acc = eval_acc1 * masked_weights
|
||||
acc = self.sum(real_acc)
|
||||
total = self.sum(masked_weights)
|
||||
self.total += total
|
||||
self.acc += acc
|
||||
return acc, self.total, self.acc
|
||||
|
||||
|
||||
def get_enwiki_512_dataset(batch_size=1, repeat_count=1, distribute_file=''):
|
||||
'''
|
||||
Get enwiki seq_length=512 dataset
|
||||
'''
|
||||
ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", "segment_ids",
|
||||
"masked_lm_positions", "masked_lm_ids",
|
||||
"masked_lm_weights",
|
||||
"next_sentence_labels"])
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
|
||||
ds = ds.repeat(repeat_count)
|
||||
|
||||
# apply batch operations
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
return ds
|
||||
|
||||
|
||||
def bert_predict():
|
||||
'''
|
||||
Predict function
|
||||
'''
|
||||
devid = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
|
||||
dataset = get_enwiki_512_dataset(bert_net_cfg.batch_size, 1)
|
||||
net_for_pretraining = BertPretrainEva(bert_net_cfg)
|
||||
net_for_pretraining.set_train(False)
|
||||
param_dict = load_checkpoint(cfg.finetune_ckpt)
|
||||
load_param_into_net(net_for_pretraining, param_dict)
|
||||
model = Model(net_for_pretraining)
|
||||
return model, dataset, net_for_pretraining
|
||||
|
||||
|
||||
def MLM_eval():
|
||||
'''
|
||||
Evaluate function
|
||||
'''
|
||||
_, dataset, net_for_pretraining = bert_predict()
|
||||
net = Model(net_for_pretraining, eval_network=net_for_pretraining, eval_indexes=[0, 1, 2],
|
||||
metrics={'name': myMetric()})
|
||||
res = net.eval(dataset, dataset_sink_mode=False)
|
||||
print("==============================================================")
|
||||
for _, v in res.items():
|
||||
print("Accuracy is: ")
|
||||
print(v)
|
||||
print("==============================================================")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
MLM_eval()
|
|
@ -0,0 +1,202 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
#################pre_train bert example on zh-wiki########################
|
||||
python run_pretrain.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy
|
||||
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
|
||||
from src.bert_net_config import bert_net_cfg
|
||||
from src.config import cfg
|
||||
from src.dataset import create_bert_dataset
|
||||
from src.lr_generator import get_bert_lr, get_bert_damping
|
||||
from src.model_thor import Model
|
||||
# from src.thor_for_bert import THOR
|
||||
from src.thor_for_bert_arg import THOR
|
||||
from src.utils import LossCallBack, BertLearningRate
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.communication.management as D
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
|
||||
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
_current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
|
||||
def run_pretrain():
|
||||
"""pre-train bert_clue"""
|
||||
parser = argparse.ArgumentParser(description='bert pre_training')
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented. (Default: Ascend)')
|
||||
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
|
||||
parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
|
||||
parser.add_argument("--device_id", type=int, default=4, help="Device id, default is 0.")
|
||||
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
|
||||
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
|
||||
parser.add_argument("--enable_lossscale", type=str, default="false", help="Use lossscale or not, default is not.")
|
||||
parser.add_argument("--do_shuffle", type=str, default="false", help="Enable shuffle for dataset, default is true.")
|
||||
parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
|
||||
parser.add_argument("--data_sink_steps", type=int, default="100", help="Sink steps for each epoch, default is 1.")
|
||||
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
|
||||
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
|
||||
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
|
||||
"default is 1000.")
|
||||
parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, "
|
||||
"meaning run all steps according to epoch number.")
|
||||
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
|
||||
parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
|
||||
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
|
||||
|
||||
args_opt = parser.parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id,
|
||||
save_graphs=True)
|
||||
context.set_context(reserve_class_name_in_scope=False)
|
||||
context.set_context(variable_memory_max_size="30GB")
|
||||
ckpt_save_dir = args_opt.save_checkpoint_path
|
||||
if args_opt.distribute == "true":
|
||||
if args_opt.device_target == 'Ascend':
|
||||
D.init('hccl')
|
||||
device_num = args_opt.device_num
|
||||
rank = args_opt.device_id % device_num
|
||||
else:
|
||||
D.init('nccl')
|
||||
device_num = D.get_group_size()
|
||||
rank = D.get_rank()
|
||||
ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'
|
||||
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
|
||||
device_num=device_num)
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
if bert_net_cfg.num_hidden_layers == 12:
|
||||
if bert_net_cfg.use_relative_positions:
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217],
|
||||
"hccl_world_groupsum1")
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217],
|
||||
"hccl_world_groupsum3")
|
||||
else:
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205],
|
||||
"hccl_world_groupsum1")
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205],
|
||||
"hccl_world_groupsum3")
|
||||
elif bert_net_cfg.num_hidden_layers == 24:
|
||||
if bert_net_cfg.use_relative_positions:
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421],
|
||||
"hccl_world_groupsum1")
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421],
|
||||
"hccl_world_groupsum3")
|
||||
else:
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397],
|
||||
"hccl_world_groupsum1")
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397],
|
||||
"hccl_world_groupsum3")
|
||||
else:
|
||||
rank = 0
|
||||
device_num = 1
|
||||
|
||||
if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
|
||||
logger.warning('Gpu only support fp32 temporarily, run with fp32.')
|
||||
bert_net_cfg.compute_type = mstype.float32
|
||||
|
||||
ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
|
||||
net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
|
||||
|
||||
new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
|
||||
if args_opt.train_steps > 0:
|
||||
new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
|
||||
else:
|
||||
args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size()
|
||||
logger.info("train steps: {}".format(args_opt.train_steps))
|
||||
|
||||
if cfg.optimizer == 'Lamb':
|
||||
lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
|
||||
end_learning_rate=cfg.Lamb.end_learning_rate,
|
||||
warmup_steps=cfg.Lamb.warmup_steps,
|
||||
decay_steps=args_opt.train_steps,
|
||||
power=cfg.Lamb.power)
|
||||
params = net_with_loss.trainable_params()
|
||||
decay_params = list(filter(cfg.Lamb.decay_filter, params))
|
||||
other_params = list(filter(lambda x: x not in decay_params, params))
|
||||
group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
|
||||
{'params': other_params},
|
||||
{'order_params': params}]
|
||||
optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
|
||||
elif cfg.optimizer == 'Momentum':
|
||||
optimizer = Momentum(net_with_loss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
|
||||
momentum=cfg.Momentum.momentum)
|
||||
elif cfg.optimizer == 'AdamWeightDecay':
|
||||
lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
|
||||
end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
|
||||
warmup_steps=cfg.AdamWeightDecay.warmup_steps,
|
||||
decay_steps=args_opt.train_steps,
|
||||
power=cfg.AdamWeightDecay.power)
|
||||
params = net_with_loss.trainable_params()
|
||||
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
|
||||
other_params = list(filter(lambda x: x not in decay_params, params))
|
||||
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
|
||||
{'params': other_params, 'weight_decay': 0.0},
|
||||
{'order_params': params}]
|
||||
|
||||
optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
|
||||
elif cfg.optimizer == "Thor":
|
||||
lr = get_bert_lr()
|
||||
damping = get_bert_damping()
|
||||
optimizer = THOR(filter(lambda x: x.requires_grad, net_with_loss.get_parameters()), lr, cfg.Thor.momentum,
|
||||
filter(lambda x: 'matrix_A' in x.name, net_with_loss.get_parameters()),
|
||||
filter(lambda x: 'matrix_G' in x.name, net_with_loss.get_parameters()),
|
||||
filter(lambda x: 'A_inv_max' in x.name, net_with_loss.get_parameters()),
|
||||
filter(lambda x: 'G_inv_max' in x.name, net_with_loss.get_parameters()),
|
||||
cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
|
||||
bert_net_cfg.batch_size, damping)
|
||||
else:
|
||||
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
|
||||
format(cfg.optimizer))
|
||||
callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
|
||||
if args_opt.enable_save_ckpt == "true":
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
|
||||
keep_checkpoint_max=args_opt.save_checkpoint_num)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck)
|
||||
callback.append(ckpoint_cb)
|
||||
|
||||
if args_opt.load_checkpoint_path:
|
||||
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
|
||||
load_param_into_net(net_with_loss, param_dict)
|
||||
|
||||
if args_opt.enable_lossscale == "true":
|
||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
|
||||
scale_factor=cfg.scale_factor,
|
||||
scale_window=cfg.scale_window)
|
||||
net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
|
||||
scale_update_cell=update_cell)
|
||||
else:
|
||||
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
|
||||
|
||||
model = Model(net_with_grads, frequency=cfg.Thor.frequency)
|
||||
model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"),
|
||||
sink_size=args_opt.data_sink_steps)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
numpy.random.seed(0)
|
||||
run_pretrain()
|
|
@ -0,0 +1,62 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "bash run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH"
|
||||
echo "for example: bash run_distribute_pretrain.sh 8 1 /path/zh-wiki/ /path/Schema.json /path/hccl.json"
|
||||
echo "It is better to use absolute path."
|
||||
echo "=============================================================================================================="
|
||||
|
||||
EPOCH_SIZE=$2
|
||||
DATA_DIR=$3
|
||||
SCHEMA_DIR=$4
|
||||
|
||||
ulimit -u unlimited
|
||||
export MINDSPORE_HCCL_CONFIG_PATH=$5
|
||||
export RANK_TABLE_FILE=$5
|
||||
export RANK_SIZE=$1
|
||||
export HCCL_CONNECT_TIMEOUT=300
|
||||
|
||||
for((i=0;i<RANK_SIZE;i++))
|
||||
do
|
||||
export DEVICE_ID=$(( $i + 0 ))
|
||||
export RANK_ID=$i
|
||||
|
||||
rm -rf LOG$i
|
||||
mkdir ./LOG$i
|
||||
cp *.py ./LOG$i
|
||||
cp -r src ./LOG$i
|
||||
cd ./LOG$i || exit
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python ../run_pretrain.py \
|
||||
--distribute="true" \
|
||||
--epoch_size=$EPOCH_SIZE \
|
||||
--device_id=$DEVICE_ID \
|
||||
--device_num=$RANK_SIZE \
|
||||
--enable_save_ckpt="true" \
|
||||
--enable_lossscale="false" \
|
||||
--do_shuffle="true" \
|
||||
--enable_data_sink="true" \
|
||||
--data_sink_steps=1000 \
|
||||
--load_checkpoint_path="" \
|
||||
--save_checkpoint_steps=5000 \
|
||||
--save_checkpoint_num=30 \
|
||||
--data_dir=$DATA_DIR \
|
||||
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
|
||||
cd ../
|
||||
done
|
|
@ -0,0 +1,46 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "bash run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR"
|
||||
echo "for example: bash run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json"
|
||||
echo "=============================================================================================================="
|
||||
|
||||
DEVICE_ID=$1
|
||||
EPOCH_SIZE=$2
|
||||
DATA_DIR=$3
|
||||
SCHEMA_DIR=$4
|
||||
|
||||
mkdir -p ms_log
|
||||
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
|
||||
CUR_DIR=`pwd`
|
||||
export GLOG_log_dir=${CUR_DIR}/ms_log
|
||||
export GLOG_logtostderr=0
|
||||
python ${PROJECT_DIR}/../run_pretrain.py \
|
||||
--distribute="false" \
|
||||
--epoch_size=$EPOCH_SIZE \
|
||||
--device_id=$DEVICE_ID \
|
||||
--enable_save_ckpt="true" \
|
||||
--enable_lossscale="true" \
|
||||
--do_shuffle="true" \
|
||||
--enable_data_sink="true" \
|
||||
--data_sink_steps=1 \
|
||||
--load_checkpoint_path="" \
|
||||
--save_checkpoint_steps=10000 \
|
||||
--save_checkpoint_num=1 \
|
||||
--data_dir=$DATA_DIR \
|
||||
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Bert Init."""
|
||||
from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \
|
||||
BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \
|
||||
BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
|
||||
from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
|
||||
BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
|
||||
EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
|
||||
SaturateCast, CreateAttentionMaskFromInputMask
|
||||
|
||||
__all__ = [
|
||||
"BertNetworkWithLoss", "BertPreTraining", "BertPretrainingLoss",
|
||||
"GetMaskedLMOutput", "GetNextSentenceOutput", "BertTrainOneStepCell", "BertTrainOneStepWithLossScaleCell",
|
||||
"BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
|
||||
"BertSelfAttention", "BertTransformer", "EmbeddingLookup",
|
||||
"EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator",
|
||||
"RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask"
|
||||
]
|
|
@ -0,0 +1,458 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Bert for pretraining."""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.initializer import initializer, TruncatedNormal
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
from mindspore.ops import _selected_ops
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
from .bert_model import BertModel
|
||||
from .config import cfg
|
||||
from .lr_generator import get_bert_damping
|
||||
from .thor_layer import Dense_Thor
|
||||
|
||||
damping = get_bert_damping()
|
||||
loss_scale = cfg.Thor.loss_scale
|
||||
GRADIENT_CLIP_TYPE = 1
|
||||
GRADIENT_CLIP_VALUE = 1.0
|
||||
|
||||
clip_grad = C.MultitypeFuncGraph("clip_grad")
|
||||
|
||||
|
||||
# pylint: disable=consider-using-in
|
||||
@clip_grad.register("Number", "Number", "Tensor")
|
||||
def _clip_grad(clip_type, clip_value, grad):
|
||||
"""
|
||||
Clip gradients.
|
||||
|
||||
Inputs:
|
||||
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
|
||||
clip_value (float): Specifies how much to clip.
|
||||
grad (tuple[Tensor]): Gradients.
|
||||
|
||||
Outputs:
|
||||
tuple[Tensor], clipped gradients.
|
||||
"""
|
||||
if clip_type != 0 and clip_type != 1:
|
||||
return grad
|
||||
dt = F.dtype(grad)
|
||||
if clip_type == 0:
|
||||
new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
|
||||
F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
else:
|
||||
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
|
||||
return new_grad
|
||||
|
||||
|
||||
class GetMaskedLMOutput(nn.Cell):
|
||||
"""
|
||||
Get masked lm output.
|
||||
|
||||
Args:
|
||||
config (BertConfig): The config of BertModel.
|
||||
|
||||
Returns:
|
||||
Tensor, masked lm output.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(GetMaskedLMOutput, self).__init__()
|
||||
self.width = config.hidden_size
|
||||
self.reshape = P.Reshape()
|
||||
self.gather = P.GatherV2()
|
||||
|
||||
weight_init = TruncatedNormal(config.initializer_range)
|
||||
self.dense = Dense_Thor(in_channels=self.width,
|
||||
out_channels=config.hidden_size,
|
||||
weight_init=weight_init,
|
||||
has_bias=True,
|
||||
bias_init='zeros',
|
||||
damping=damping,
|
||||
loss_scale=loss_scale,
|
||||
frequency=1,
|
||||
activation=config.hidden_act,
|
||||
batch_size=config.batch_size).to_float(config.compute_type)
|
||||
self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)
|
||||
self.output_bias = Parameter(
|
||||
initializer(
|
||||
'zero',
|
||||
config.vocab_size),
|
||||
name='output_bias')
|
||||
self.matmul = P.MatMul(transpose_b=True)
|
||||
self.log_softmax = nn.LogSoftmax(axis=-1)
|
||||
self.shape_flat_offsets = (-1, 1)
|
||||
self.rng = Tensor(np.array(range(0, config.batch_size)).astype(np.int32))
|
||||
self.last_idx = (-1,)
|
||||
self.shape_flat_sequence_tensor = (config.batch_size * config.seq_length, self.width)
|
||||
self.seq_length_tensor = Tensor(np.array((config.seq_length,)).astype(np.int32))
|
||||
self.cast = P.Cast()
|
||||
self.compute_type = config.compute_type
|
||||
self.dtype = config.dtype
|
||||
|
||||
def construct(self,
|
||||
input_tensor,
|
||||
output_weights,
|
||||
positions):
|
||||
"""construct of GetMaskedLMOutput"""
|
||||
flat_offsets = self.reshape(
|
||||
self.rng * self.seq_length_tensor, self.shape_flat_offsets)
|
||||
flat_position = self.reshape(positions + flat_offsets, self.last_idx)
|
||||
flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor)
|
||||
input_tensor = self.gather(flat_sequence_tensor, flat_position, 0)
|
||||
input_tensor = self.cast(input_tensor, self.compute_type)
|
||||
output_weights = self.cast(output_weights, self.compute_type)
|
||||
input_tensor = self.dense(input_tensor)
|
||||
input_tensor = self.layernorm(input_tensor)
|
||||
logits = self.matmul(input_tensor, output_weights)
|
||||
logits = self.cast(logits, self.dtype)
|
||||
logits = logits + self.output_bias
|
||||
log_probs = self.log_softmax(logits)
|
||||
return log_probs
|
||||
|
||||
|
||||
class GetNextSentenceOutput(nn.Cell):
|
||||
"""
|
||||
Get next sentence output.
|
||||
|
||||
Args:
|
||||
config (BertConfig): The config of Bert.
|
||||
|
||||
Returns:
|
||||
Tensor, next sentence output.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(GetNextSentenceOutput, self).__init__()
|
||||
self.log_softmax = _selected_ops.LogSoftmax()
|
||||
weight_init = TruncatedNormal(config.initializer_range)
|
||||
self.dense = nn.Dense(config.hidden_size, 2,
|
||||
weight_init=weight_init, has_bias=True).to_float(config.compute_type)
|
||||
self.dtype = config.dtype
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, input_tensor):
|
||||
logits = self.dense(input_tensor)
|
||||
logits = self.cast(logits, self.dtype)
|
||||
log_prob = self.log_softmax(logits)
|
||||
return log_prob
|
||||
|
||||
|
||||
class BertPreTraining(nn.Cell):
|
||||
"""
|
||||
Bert pretraining network.
|
||||
|
||||
Args:
|
||||
config (BertConfig): The config of BertModel.
|
||||
is_training (bool): Specifies whether to use the training mode.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings.
|
||||
|
||||
Returns:
|
||||
Tensor, prediction_scores, seq_relationship_score.
|
||||
"""
|
||||
|
||||
def __init__(self, config, is_training, use_one_hot_embeddings):
|
||||
super(BertPreTraining, self).__init__()
|
||||
self.bert = BertModel(config, is_training, use_one_hot_embeddings)
|
||||
self.cls1 = GetMaskedLMOutput(config)
|
||||
self.cls2 = GetNextSentenceOutput(config)
|
||||
|
||||
def construct(self, input_ids, input_mask, token_type_id,
|
||||
masked_lm_positions):
|
||||
sequence_output, pooled_output, embedding_table = \
|
||||
self.bert(input_ids, token_type_id, input_mask)
|
||||
prediction_scores = self.cls1(sequence_output,
|
||||
embedding_table,
|
||||
masked_lm_positions)
|
||||
seq_relationship_score = self.cls2(pooled_output)
|
||||
return prediction_scores, seq_relationship_score
|
||||
|
||||
|
||||
class BertPretrainingLoss(nn.Cell):
|
||||
"""
|
||||
Provide bert pre-training loss.
|
||||
|
||||
Args:
|
||||
config (BertConfig): The config of BertModel.
|
||||
|
||||
Returns:
|
||||
Tensor, total loss.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(BertPretrainingLoss, self).__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.reshape = P.Reshape()
|
||||
self.last_idx = (-1,)
|
||||
self.neg = P.Neg()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, prediction_scores, seq_relationship_score, masked_lm_ids,
|
||||
masked_lm_weights, next_sentence_labels):
|
||||
"""Defines the computation performed."""
|
||||
label_ids = self.reshape(masked_lm_ids, self.last_idx)
|
||||
label_weights = self.cast(self.reshape(masked_lm_weights, self.last_idx), mstype.float32)
|
||||
one_hot_labels = self.onehot(label_ids, self.vocab_size, self.on_value, self.off_value)
|
||||
|
||||
per_example_loss = self.neg(self.reduce_sum(prediction_scores * one_hot_labels, self.last_idx))
|
||||
numerator = self.reduce_sum(label_weights * per_example_loss, ())
|
||||
denominator = self.reduce_sum(label_weights, ()) + self.cast(F.tuple_to_array((1e-5,)), mstype.float32)
|
||||
masked_lm_loss = numerator / denominator
|
||||
|
||||
# next_sentence_loss
|
||||
labels = self.reshape(next_sentence_labels, self.last_idx)
|
||||
one_hot_labels = self.onehot(labels, 2, self.on_value, self.off_value)
|
||||
per_example_loss = self.neg(self.reduce_sum(
|
||||
one_hot_labels * seq_relationship_score, self.last_idx))
|
||||
next_sentence_loss = self.reduce_mean(per_example_loss, self.last_idx)
|
||||
|
||||
# total_loss
|
||||
total_loss = masked_lm_loss + next_sentence_loss
|
||||
|
||||
return total_loss
|
||||
|
||||
|
||||
class BertNetworkWithLoss(nn.Cell):
|
||||
"""
|
||||
Provide bert pre-training loss through network.
|
||||
|
||||
Args:
|
||||
config (BertConfig): The config of BertModel.
|
||||
is_training (bool): Specifies whether to use the training mode.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings. Default: False.
|
||||
|
||||
Returns:
|
||||
Tensor, the loss of the network.
|
||||
"""
|
||||
|
||||
def __init__(self, config, is_training, use_one_hot_embeddings=False):
|
||||
super(BertNetworkWithLoss, self).__init__()
|
||||
self.bert = BertPreTraining(config, is_training, use_one_hot_embeddings)
|
||||
self.loss = BertPretrainingLoss(config)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights):
|
||||
"""construct of BertNetworkWithLoss"""
|
||||
prediction_scores, seq_relationship_score = \
|
||||
self.bert(input_ids, input_mask, token_type_id, masked_lm_positions)
|
||||
total_loss = self.loss(prediction_scores, seq_relationship_score,
|
||||
masked_lm_ids, masked_lm_weights, next_sentence_labels)
|
||||
return self.cast(total_loss, mstype.float32)
|
||||
|
||||
|
||||
class BertTrainOneStepCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of bert network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
sens (Number): The adjust parameter. Default: 1.0.
|
||||
"""
|
||||
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(BertTrainOneStepCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
self.reducer_flag = False
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = None
|
||||
if self.reducer_flag:
|
||||
mean = context.get_auto_parallel_context("mirror_mean")
|
||||
degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
|
||||
self.cast = P.Cast()
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
def set_sens(self, value):
|
||||
self.sens = value
|
||||
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights):
|
||||
"""Defines the computation performed."""
|
||||
weights = self.weights
|
||||
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
self.cast(F.tuple_to_array((self.sens,)),
|
||||
mstype.float32))
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
if self.reducer_flag:
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
succ = self.optimizer(grads)
|
||||
return F.depend(loss, succ)
|
||||
|
||||
|
||||
grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
reciprocal = P.Reciprocal()
|
||||
|
||||
|
||||
@grad_scale.register("Tensor", "Tensor")
|
||||
def tensor_grad_scale(scale, grad):
|
||||
return grad * reciprocal(scale)
|
||||
|
||||
|
||||
class BertTrainOneStepWithLossScaleCell(nn.Cell):
|
||||
"""
|
||||
Encapsulation class of bert network training.
|
||||
|
||||
Append an optimizer to the training network after that the construct
|
||||
function can be called to create the backward graph.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, network, optimizer, scale_update_cell=None):
|
||||
super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation('grad',
|
||||
get_by_list=True,
|
||||
sens_param=True)
|
||||
self.reducer_flag = False
|
||||
self.allreduce = P.AllReduce()
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
self.grad_reducer = F.identity
|
||||
self.degree = 1
|
||||
if self.reducer_flag:
|
||||
self.degree = get_group_size()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
|
||||
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
|
||||
self.cast = P.Cast()
|
||||
self.alloc_status = P.NPUAllocFloatStatus()
|
||||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.loss_scale = None
|
||||
self.loss_scaling_manager = scale_update_cell
|
||||
if scale_update_cell:
|
||||
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
|
||||
name="loss_scale")
|
||||
|
||||
@C.add_flags(has_effect=True)
|
||||
def construct(self,
|
||||
input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
sens=None):
|
||||
"""Defines the computation performed."""
|
||||
weights = self.weights
|
||||
loss = self.network(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights)
|
||||
if sens is None:
|
||||
scaling_sens = self.loss_scale
|
||||
else:
|
||||
scaling_sens = sens
|
||||
# alloc status and clear should be right before gradoperation
|
||||
init = self.alloc_status()
|
||||
self.clear_before_grad(init)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
token_type_id,
|
||||
next_sentence_labels,
|
||||
masked_lm_positions,
|
||||
masked_lm_ids,
|
||||
masked_lm_weights,
|
||||
self.cast(scaling_sens,
|
||||
mstype.float32))
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
|
||||
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
|
||||
self.get_status(init)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
if self.is_distributed:
|
||||
# sum overflow flag over devices
|
||||
flag_reduce = self.allreduce(flag_sum)
|
||||
cond = self.less_equal(self.base, flag_reduce)
|
||||
else:
|
||||
cond = self.less_equal(self.base, flag_sum)
|
||||
overflow = cond
|
||||
if sens is None:
|
||||
overflow = self.loss_scaling_manager(self.loss_scale, cond)
|
||||
if overflow:
|
||||
succ = False
|
||||
else:
|
||||
succ = self.optimizer(grads)
|
||||
ret = (loss, cond, scaling_sens)
|
||||
return F.depend(ret, succ)
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,89 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in dataset.py, run_pretrain.py
|
||||
Including two kinds of network: \
|
||||
base: Goole BERT-base(the base version of BERT model).
|
||||
large: BERT-NEZHA(a Chinese pretrained language model developed by Huawei, which introduced a improvement of \
|
||||
Functional Relative Posetional Encoding as an effective positional encoding scheme).
|
||||
"""
|
||||
import mindspore.common.dtype as mstype
|
||||
from .bert_model import BertConfig
|
||||
from .config import cfg
|
||||
|
||||
if cfg.bert_network == 'base':
|
||||
bert_net_cfg = BertConfig(
|
||||
batch_size=cfg.Thor.batch_size,
|
||||
seq_length=128,
|
||||
vocab_size=21128,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16
|
||||
)
|
||||
if cfg.bert_network == 'nezha':
|
||||
bert_net_cfg = BertConfig(
|
||||
batch_size=cfg.Thor.batch_size,
|
||||
seq_length=128,
|
||||
vocab_size=21128,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
intermediate_size=4096,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=True,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16
|
||||
)
|
||||
if cfg.bert_network == 'large':
|
||||
bert_net_cfg = BertConfig(
|
||||
batch_size=cfg.Thor.batch_size,
|
||||
seq_length=512,
|
||||
vocab_size=30522,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
intermediate_size=4096,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
enable_fused_layernorm=True
|
||||
)
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in dataset.py, run_pretrain.py
|
||||
"""
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
cfg = edict({
|
||||
'bert_network': 'large',
|
||||
'loss_scale_value': 65536,
|
||||
'scale_factor': 2,
|
||||
'scale_window': 1000,
|
||||
'optimizer': 'Thor',
|
||||
'AdamWeightDecay': edict({
|
||||
'learning_rate': 3e-5,
|
||||
'end_learning_rate': 1e-10,
|
||||
'power': 5.0,
|
||||
'weight_decay': 1e-5,
|
||||
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
'eps': 1e-6,
|
||||
'warmup_steps': 10000,
|
||||
}),
|
||||
'Lamb': edict({
|
||||
'learning_rate': 3e-5,
|
||||
'end_learning_rate': 1e-10,
|
||||
'power': 10.0,
|
||||
'warmup_steps': 10000,
|
||||
'weight_decay': 0.01,
|
||||
'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
|
||||
'eps': 1e-6,
|
||||
}),
|
||||
'Momentum': edict({
|
||||
'learning_rate': 2e-5,
|
||||
'momentum': 0.9,
|
||||
}),
|
||||
'Thor': edict({
|
||||
'momentum': 0.9,
|
||||
'weight_decay': 5e-4,
|
||||
'loss_scale': 1,
|
||||
'frequency': 10,
|
||||
'batch_size': 8,
|
||||
}),
|
||||
})
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Data operations, will be used in run_pretrain.py
|
||||
"""
|
||||
import os
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine.datasets as de
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
from mindspore import log as logger
|
||||
from .bert_net_config import bert_net_cfg
|
||||
|
||||
|
||||
def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, schema_dir=None):
|
||||
"""create train dataset"""
|
||||
# apply repeat operations
|
||||
files = os.listdir(data_dir)
|
||||
data_files = []
|
||||
for file_name in files:
|
||||
if "tfrecord" in file_name:
|
||||
data_files.append(os.path.join(data_dir, file_name))
|
||||
data_files = sorted(data_files)
|
||||
ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None,
|
||||
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
|
||||
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
|
||||
shuffle=de.Shuffle.FILES if do_shuffle == "true" else False,
|
||||
num_shards=device_num, shard_id=rank, shard_equal_rows=True)
|
||||
ori_dataset_size = ds.get_dataset_size()
|
||||
print('origin dataset size: ', ori_dataset_size)
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
|
||||
# apply batch operations
|
||||
ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
|
||||
logger.info("data size: {}".format(ds.get_dataset_size()))
|
||||
logger.info("repeat count: {}".format(ds.get_repeat_count()))
|
||||
return ds
|
||||
|
||||
|
||||
def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy",
|
||||
data_file_path=None, schema_file_path=None):
|
||||
"""create finetune or evaluation dataset"""
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
|
||||
columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"])
|
||||
if assessment_method == "Spearman_correlation":
|
||||
type_cast_op_float = C.TypeCast(mstype.float32)
|
||||
ds = ds.map(input_columns="label_ids", operations=type_cast_op_float)
|
||||
else:
|
||||
ds = ds.map(input_columns="label_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
|
||||
ds = ds.repeat(repeat_count)
|
||||
# apply shuffle operation
|
||||
buffer_size = 960
|
||||
ds = ds.shuffle(buffer_size=buffer_size)
|
||||
# apply batch operations
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
return ds
|
||||
|
||||
|
||||
def create_classification_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy",
|
||||
data_file_path=None, schema_file_path=None):
|
||||
"""create finetune or evaluation dataset"""
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
|
||||
columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"])
|
||||
if assessment_method == "Spearman_correlation":
|
||||
type_cast_op_float = C.TypeCast(mstype.float32)
|
||||
ds = ds.map(input_columns="label_ids", operations=type_cast_op_float)
|
||||
else:
|
||||
ds = ds.map(input_columns="label_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
|
||||
ds = ds.repeat(repeat_count)
|
||||
# apply shuffle operation
|
||||
buffer_size = 960
|
||||
ds = ds.shuffle(buffer_size=buffer_size)
|
||||
# apply batch operations
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
return ds
|
||||
|
||||
|
||||
def create_squad_dataset(batch_size=1, repeat_count=1, data_file_path=None, schema_file_path=None, is_training=True):
|
||||
"""create finetune or evaluation dataset"""
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
if is_training:
|
||||
ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
|
||||
columns_list=["input_ids", "input_mask", "segment_ids",
|
||||
"start_positions", "end_positions",
|
||||
"unique_ids", "is_impossible"])
|
||||
ds = ds.map(input_columns="start_positions", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="end_positions", operations=type_cast_op)
|
||||
else:
|
||||
ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None,
|
||||
columns_list=["input_ids", "input_mask", "segment_ids", "unique_ids"])
|
||||
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
|
||||
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
|
||||
ds = ds.repeat(repeat_count)
|
||||
# apply shuffle operation
|
||||
buffer_size = 960
|
||||
ds = ds.shuffle(buffer_size=buffer_size)
|
||||
# apply batch operations
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
return ds
|
|
@ -0,0 +1,177 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dataset help for minddata dataset"""
|
||||
import os
|
||||
|
||||
from mindspore import context
|
||||
from mindspore._checkparam import check_bool, check_int
|
||||
from mindspore.parallel._utils import _get_device_num, _need_to_full
|
||||
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _to_full_shapes
|
||||
|
||||
|
||||
def _send_data(dataset, epoch_num):
|
||||
"""Engine dataset to write data to tdt queue."""
|
||||
if not hasattr(dataset, '__has_sent__'):
|
||||
exec_dataset = dataset.__TRANSFER_DATASET__
|
||||
exec_dataset.send(epoch_num)
|
||||
dataset.__has_sent__ = True
|
||||
|
||||
|
||||
def _send_data_no_flag(dataset, epoch_num):
|
||||
"""Engine dataset to write data to tdt queue directly."""
|
||||
exec_dataset = dataset.__TRANSFER_DATASET__
|
||||
exec_dataset.send(epoch_num)
|
||||
|
||||
|
||||
class DatasetHelper:
|
||||
"""
|
||||
Help function to use the Minddata dataset.
|
||||
|
||||
According to different context, change the iter of dataset, to use the same for loop in different context.
|
||||
|
||||
Note:
|
||||
The iter of DatasetHelper will give one epoch data.
|
||||
|
||||
Args:
|
||||
dataset (DataSet): The training dataset iterator.
|
||||
dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True.
|
||||
sink_size (int): Control the amount of data each sink.
|
||||
If sink_size=-1, sink the complete dataset each epoch.
|
||||
If sink_size>0, sink sink_size data each epoch. Default: -1.
|
||||
|
||||
Examples:
|
||||
>>> dataset_helper = DatasetHelper(dataset)
|
||||
>>> for inputs in dataset_helper:
|
||||
>>> outputs = network(*inputs)
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=0):
|
||||
check_bool(dataset_sink_mode)
|
||||
check_int(sink_size)
|
||||
if sink_size < -1 or sink_size == 0:
|
||||
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
|
||||
|
||||
if dataset_sink_mode:
|
||||
if context.get_context("enable_ge"):
|
||||
iterclass = _DatasetIterGE
|
||||
else:
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
iterclass = _DatasetIterMSLoopSink
|
||||
elif context.get_context("device_target") == "GPU":
|
||||
ms_role = os.getenv("MS_ROLE")
|
||||
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
||||
iterclass = _DatasetIterPSLite
|
||||
else:
|
||||
iterclass = _DatasetIterMS
|
||||
elif context.get_context("device_target") == "CPU":
|
||||
raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
|
||||
self.iter = iterclass(dataset, sink_size, epoch_num, iter_first_order)
|
||||
else:
|
||||
iterclass = _DatasetIterNormal
|
||||
self.iter = iterclass(dataset)
|
||||
|
||||
def __iter__(self):
|
||||
return self.iter.__iter__()
|
||||
|
||||
# A temp solution for loop sink. Delete later
|
||||
def types_shapes(self):
|
||||
"""Get the types and shapes from dataset on current config."""
|
||||
return self.iter.types_shapes()
|
||||
|
||||
def sink_size(self):
|
||||
"""Get sink_size for every iteration."""
|
||||
return self.iter.get_sink_size()
|
||||
|
||||
def stop_send(self):
|
||||
"""Free up resources about data sink."""
|
||||
self.iter.stop_send()
|
||||
|
||||
|
||||
class _DatasetIter:
|
||||
"""Base iter for dataset helper"""
|
||||
|
||||
def __init__(self, dataset, sink_size, epoch_num):
|
||||
self.dataset = dataset
|
||||
self.sink_size = sink_size
|
||||
self.sink_count = 1
|
||||
|
||||
if not hasattr(dataset, '__TRANSFER_DATASET__'):
|
||||
if hasattr(dataset, '__loop_size__'):
|
||||
self.sink_size = dataset.__loop_size__
|
||||
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
|
||||
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
|
||||
|
||||
if not hasattr(dataset, '__no_send__'):
|
||||
_send_data(dataset, epoch_num)
|
||||
else:
|
||||
_send_data_no_flag(dataset, epoch_num)
|
||||
|
||||
self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
|
||||
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
|
||||
|
||||
def __iter__(self):
|
||||
self.index = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.index >= self.sink_count:
|
||||
raise StopIteration()
|
||||
self.index += 1
|
||||
return self.op()
|
||||
|
||||
def types_shapes(self):
|
||||
return self.dataset_types, self.dataset_shapes
|
||||
|
||||
def get_sink_count(self, dataset, sink_size, iter_first_order):
|
||||
sink_count = 1
|
||||
if hasattr(dataset, '__loop_size__'):
|
||||
loop_size = dataset.__loop_size__ + iter_first_order
|
||||
sink_count = int(sink_size / loop_size) * 2
|
||||
return sink_count
|
||||
|
||||
def get_sink_size(self):
|
||||
"""get sink_size to device"""
|
||||
sink_size = 1
|
||||
if hasattr(self.dataset, '__loop_size__'):
|
||||
sink_size = self.dataset.__loop_size__
|
||||
else:
|
||||
if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend":
|
||||
if self.sink_size > 0:
|
||||
sink_size = self.sink_size
|
||||
else:
|
||||
sink_size = self.dataset.get_dataset_size()
|
||||
return sink_size
|
||||
|
||||
|
||||
class _DatasetIterMSLoopSink(_DatasetIter):
|
||||
"""Iter for context (device_target=Ascend)"""
|
||||
|
||||
def __init__(self, dataset, sink_size, epoch_num, iter_first_order):
|
||||
super().__init__(dataset, sink_size, epoch_num)
|
||||
self.sink_count = self.get_sink_count(dataset, sink_size, iter_first_order)
|
||||
ms_role = os.getenv("MS_ROLE")
|
||||
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
||||
self.sink_count = 1
|
||||
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
|
||||
# use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
|
||||
# compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
|
||||
if _need_to_full():
|
||||
device_num = _get_device_num()
|
||||
self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)
|
||||
|
||||
def op():
|
||||
return tuple()
|
||||
|
||||
self.op = op
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
config settings, will be used in finetune.py
|
||||
"""
|
||||
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from .bert_model import BertConfig
|
||||
|
||||
cfg = edict({
|
||||
'task': 'NER',
|
||||
'num_labels': 41,
|
||||
'data_file': '',
|
||||
'schema_file': None,
|
||||
'finetune_ckpt': '',
|
||||
'use_crf': False,
|
||||
'clue_benchmark': False,
|
||||
})
|
||||
|
||||
bert_net_cfg = BertConfig(
|
||||
batch_size=8 if not cfg.clue_benchmark else 1,
|
||||
seq_length=512,
|
||||
vocab_size=30522,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
intermediate_size=4096,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.0,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
)
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""fused layernorm"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.primitive import constexpr
|
||||
|
||||
__all__ = ['FusedLayerNorm']
|
||||
|
||||
|
||||
@constexpr
|
||||
def get_shape_for_norm(x_shape, begin_norm_axis):
|
||||
print("input_shape: ", x_shape)
|
||||
norm_shape = x_shape[begin_norm_axis:]
|
||||
output_shape = (1, -1, 1, int(np.prod(norm_shape)))
|
||||
print("output_shape: ", output_shape)
|
||||
return output_shape
|
||||
|
||||
|
||||
class FusedLayerNorm(Cell):
|
||||
r"""
|
||||
Applies Layer Normalization over a mini-batch of inputs.
|
||||
|
||||
Layer normalization is widely used in recurrent neural networks. It applies
|
||||
normalization over a mini-batch of inputs for each single training case as described
|
||||
in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
|
||||
normalization, layer normalization performs exactly the same computation at training and
|
||||
testing times. It can be described using the following formula. It is applied across all channels
|
||||
and pixel but only one batch size.
|
||||
|
||||
.. math::
|
||||
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
|
||||
|
||||
Args:
|
||||
normalized_shape (Union(tuple[int], list[int]): The normalization is performed over axis
|
||||
`begin_norm_axis ... R - 1`.
|
||||
begin_norm_axis (int): It first normalization dimension: normalization will be performed along dimensions
|
||||
`begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
|
||||
begin_params_axis (int): The first parameter(beta, gamma)dimension: scale and centering parameters
|
||||
will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
|
||||
the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
|
||||
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
|
||||
'he_uniform', etc. Default: 'ones'.
|
||||
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
|
||||
The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
|
||||
'he_uniform', etc. Default: 'zeros'.
|
||||
use_batch_nrom (bool): Whether use batchnorm to preocess.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
|
||||
and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
|
||||
>>> shape1 = x.shape[1:]
|
||||
>>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
|
||||
>>> m(x)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
normalized_shape,
|
||||
begin_norm_axis=-1,
|
||||
begin_params_axis=-1,
|
||||
gamma_init='ones',
|
||||
beta_init='zeros',
|
||||
use_batch_norm=False):
|
||||
super(FusedLayerNorm, self).__init__()
|
||||
if not isinstance(normalized_shape, (tuple, list)):
|
||||
raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
|
||||
.format(normalized_shape, type(normalized_shape)))
|
||||
self.normalized_shape = normalized_shape
|
||||
self.begin_norm_axis = begin_norm_axis
|
||||
self.begin_params_axis = begin_params_axis
|
||||
self.gamma = Parameter(initializer(
|
||||
gamma_init, normalized_shape), name="gamma")
|
||||
self.beta = Parameter(initializer(
|
||||
beta_init, normalized_shape), name="beta")
|
||||
self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, begin_params_axis=self.begin_params_axis)
|
||||
|
||||
self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5)
|
||||
self.use_batch_norm = use_batch_norm
|
||||
|
||||
def construct(self, input_x):
|
||||
"""construct of FusedLayerNorm"""
|
||||
if self.use_batch_norm and self.training:
|
||||
ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0)
|
||||
zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0)
|
||||
shape_x = F.shape(input_x)
|
||||
norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis)
|
||||
input_x = F.reshape(input_x, norm_shape)
|
||||
output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None)
|
||||
output = F.reshape(output, shape_x)
|
||||
y = output * self.gamma + self.beta
|
||||
else:
|
||||
y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
|
||||
return y
|
||||
|
||||
def extend_repr(self):
|
||||
"""Display instance object as string."""
|
||||
s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format(
|
||||
self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
|
||||
return s
|
|
@ -0,0 +1,184 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""grad_reducer_thor"""
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.communication.management import GlobalComm, get_group_size
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp
|
||||
|
||||
reduce_opt = C.MultitypeFuncGraph("reduce_opt")
|
||||
|
||||
_all_reduce_G = AllReduce()
|
||||
|
||||
|
||||
def _init_optimizer_allreduce(group):
|
||||
global _all_reduce_G
|
||||
_all_reduce_G = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
|
||||
_all_reduce_G.add_prim_attr('fusion', group)
|
||||
|
||||
|
||||
@reduce_opt.register("Function", "Number", "Tensor")
|
||||
def _tensors_allreduce_mean(mul, degree, grad):
|
||||
degree = F.scalar_cast(degree, F.dtype(grad))
|
||||
grad = _all_reduce_G(grad)
|
||||
cast_op = P.Cast()
|
||||
return mul(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad)))
|
||||
|
||||
|
||||
@reduce_opt.register("Bool", "Tensor")
|
||||
def _tensors_allreduce(allreduce_filter, grad):
|
||||
if allreduce_filter:
|
||||
return _all_reduce_G(grad)
|
||||
return grad
|
||||
|
||||
|
||||
_get_datatype = C.MultitypeFuncGraph("_get_datatype")
|
||||
|
||||
|
||||
@_get_datatype.register("Tensor")
|
||||
def _tensors_get_datatype(grad):
|
||||
"""
|
||||
Acquire gradient datatype.
|
||||
|
||||
Args:
|
||||
grad (Tensor): The gradient tensor before operation.
|
||||
|
||||
Returns:
|
||||
mstype, the datatype of gradient.
|
||||
"""
|
||||
return F.dtype(grad)
|
||||
|
||||
|
||||
_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")
|
||||
|
||||
|
||||
@_cast_datatype.register("TypeType", "Tensor")
|
||||
def _tensors_cast_datatype(datatype, grad):
|
||||
"""
|
||||
Cast gradient to datatype.
|
||||
|
||||
Args:
|
||||
datatype (mstype): the destination datatype of gradient.
|
||||
grad (Tensor): The gradient tensor before operation.
|
||||
|
||||
Returns:
|
||||
Tensor, the gradient tensor after operation.
|
||||
"""
|
||||
return F.cast(grad, datatype)
|
||||
|
||||
|
||||
class DistributedGradReducerThor1(Cell):
|
||||
"""
|
||||
A distributed optimizer.
|
||||
|
||||
Constructs a gradient reducer Cell, which applies communication and average operations on
|
||||
single-process gradient values.
|
||||
|
||||
Args:
|
||||
parameters (list): the parameters to be updated.
|
||||
mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: False.
|
||||
degree (int): The mean coefficient. Usually it equals to device number. Default: None.
|
||||
|
||||
Raises:
|
||||
ValueError: If degree is not a int or less than 0.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.communication import init, get_group_size
|
||||
>>> from mindspore.ops import composite as C
|
||||
>>> from mindspore.ops import operations as P
|
||||
>>> from mindspore.ops import functional as F
|
||||
>>> from mindspore import context
|
||||
>>> from mindspore import nn
|
||||
>>> from mindspore import ParallelMode, ParameterTuple
|
||||
>>>
|
||||
>>> device_id = int(os.environ["DEVICE_ID"])
|
||||
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
|
||||
>>> device_id=int(device_id), enable_hccl=True)
|
||||
>>> init()
|
||||
>>> context.reset_auto_parallel_context()
|
||||
>>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
|
||||
>>>
|
||||
>>>
|
||||
>>> class TrainingWrapper(nn.Cell):
|
||||
>>> def __init__(self, network, optimizer, sens=1.0):
|
||||
>>> super(TrainingWrapper, self).__init__(auto_prefix=False)
|
||||
>>> self.network = network
|
||||
>>> self.network.add_flags(defer_inline=True)
|
||||
>>> self.weights = ParameterTuple(network.trainable_params())
|
||||
>>> self.optimizer = optimizer
|
||||
>>> self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
|
||||
>>> self.sens = sens
|
||||
>>> self.reducer_flag = False
|
||||
>>> self.grad_reducer = None
|
||||
>>> self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
>>> if self.parallel_mode in [ParallelMode.DATA_PARALLEL,
|
||||
>>> ParallelMode.HYBRID_PARALLEL]:
|
||||
>>> self.reducer_flag = True
|
||||
>>> if self.reducer_flag:
|
||||
>>> mean = context.get_auto_parallel_context("mirror_mean")
|
||||
>>> if mean.get_device_num_is_set():
|
||||
>>> degree = context.get_auto_parallel_context("device_num")
|
||||
>>> else:
|
||||
>>> degree = get_group_size()
|
||||
>>> self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
>>>
|
||||
>>> def construct(self, *args):
|
||||
>>> weights = self.weights
|
||||
>>> loss = self.network(*args)
|
||||
>>> sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
>>> grads = self.grad(self.network, weights)(*args, sens)
|
||||
>>> if self.reducer_flag:
|
||||
>>> # apply grad reducer on grads
|
||||
>>> grads = self.grad_reducer(grads)
|
||||
>>> return F.depend(loss, self.optimizer(grads))
|
||||
>>>
|
||||
>>> network = Net()
|
||||
>>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> train_cell = TrainingWrapper(network, optimizer)
|
||||
>>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
|
||||
>>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
|
||||
>>> grads = train_cell(inputs, label)
|
||||
"""
|
||||
|
||||
def __init__(self, parameters, group, mean=True, degree=None):
|
||||
super(DistributedGradReducerThor1, self).__init__(auto_prefix=False)
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.mul = P.Mul()
|
||||
if degree is None:
|
||||
self.degree = get_group_size()
|
||||
else:
|
||||
if not isinstance(degree, int) or degree <= 0:
|
||||
raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int")
|
||||
self.degree = degree
|
||||
self.mean = mean
|
||||
self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
|
||||
_init_optimizer_allreduce(group)
|
||||
|
||||
def construct(self, grads):
|
||||
"""construct of DistributedGradReducerThor1"""
|
||||
# In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
|
||||
# result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
|
||||
# and cast back after the operation.
|
||||
datatypes = self.hyper_map(F.partial(_get_datatype), grads)
|
||||
grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
|
||||
|
||||
if self.mean:
|
||||
new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads)
|
||||
else:
|
||||
new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads)
|
||||
|
||||
new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
|
||||
return new_grad
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""learning rate generator"""
|
||||
import numpy as np
|
||||
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
|
||||
def get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps, poly_power):
|
||||
"""
|
||||
generate learning rate array
|
||||
|
||||
Args:
|
||||
lr_init(float): init learning rate
|
||||
lr_end(float): end learning rate
|
||||
lr_max(float): max learning rate
|
||||
warmup_steps(int): number of warmup epochs
|
||||
total_steps(int): total epoch of training
|
||||
poly_power(int): poly learning rate power
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array
|
||||
"""
|
||||
lr_each_step = []
|
||||
if warmup_steps != 0:
|
||||
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
|
||||
else:
|
||||
inc_each_step = 0
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = float(lr_init) + inc_each_step * float(i)
|
||||
else:
|
||||
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
|
||||
lr = float(lr_max - lr_end) * (base ** poly_power)
|
||||
lr = lr + lr_end
|
||||
if lr < 0.0:
|
||||
lr = 0.0
|
||||
lr_each_step.append(lr)
|
||||
|
||||
learning_rate = np.array(lr_each_step).astype(np.float32)
|
||||
current_step = global_step
|
||||
learning_rate = learning_rate[current_step:]
|
||||
return learning_rate
|
||||
|
||||
|
||||
# bert kfac hyperparam setting
|
||||
def get_bert_lr():
|
||||
learning_rate = Tensor(
|
||||
get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=4e-4, warmup_steps=0, total_steps=30000,
|
||||
poly_power=1))
|
||||
return learning_rate
|
||||
|
||||
|
||||
def get_bert_damping():
|
||||
damping = Tensor(
|
||||
get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=5e-2, warmup_steps=0, total_steps=30000,
|
||||
poly_power=1))
|
||||
return damping
|
|
@ -0,0 +1,784 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Model."""
|
||||
import math
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
|
||||
import numpy as np
|
||||
from mindspore._c_expression import init_exec_dataset
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import log as logger
|
||||
from mindspore import nn
|
||||
from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, check_bool, check_int
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.dtype import pytype_to_dtype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.metrics import Loss
|
||||
from mindspore.nn.metrics import get_metrics
|
||||
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
|
||||
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
||||
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
|
||||
from mindspore.parallel._utils import _need_to_full
|
||||
from mindspore.train import amp
|
||||
from mindspore.train._utils import _to_full_tensor
|
||||
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
from .dataset_helper import DatasetHelper
|
||||
|
||||
|
||||
def _convert_type(types):
|
||||
"""
|
||||
Convert from numpy type to tensor type.
|
||||
|
||||
Args:
|
||||
types (list): Numpy type list of element in dataset.
|
||||
|
||||
Returns:
|
||||
list, list of element in dataset.
|
||||
"""
|
||||
ms_types = []
|
||||
for np_type in types:
|
||||
ms_type = pytype_to_dtype(np_type)
|
||||
ms_types.append(ms_type)
|
||||
return ms_types
|
||||
|
||||
|
||||
def _get_types_and_shapes(dataset):
|
||||
"""Get dataset types and shapes."""
|
||||
dataset_types = _convert_type(dataset.output_types())
|
||||
dataset_shapes = dataset.output_shapes()
|
||||
return dataset_types, dataset_shapes
|
||||
|
||||
|
||||
def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
|
||||
"""Initialize and execute the dataset graph."""
|
||||
batch_size = exec_dataset.get_batch_size()
|
||||
input_indexs = exec_dataset.input_indexs
|
||||
|
||||
# transform data format
|
||||
dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
|
||||
init_exec_dataset(exec_dataset.__ME_INITED__,
|
||||
dataset_size,
|
||||
batch_size,
|
||||
dataset_types,
|
||||
dataset_shapes,
|
||||
input_indexs,
|
||||
phase=phase,
|
||||
need_run=False)
|
||||
|
||||
|
||||
class Model:
|
||||
"""
|
||||
High-Level API for Training or Testing.
|
||||
|
||||
`Model` groups layers into an object with training and inference features.
|
||||
|
||||
Args:
|
||||
network (Cell): The training or testing network.
|
||||
loss_fn (Cell): Objective function, if loss_fn is None, the
|
||||
network should contain the logic of loss and grads calculation, and the logic
|
||||
of parallel if needed. Default: None.
|
||||
optimizer (Cell): Optimizer for updating the weights. Default: None.
|
||||
metrics (Union[dict, set]): Dict or set of metrics to be evaluated by the model during
|
||||
training and testing. eg: {'accuracy', 'recall'}. Default: None.
|
||||
eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as
|
||||
`eval_network`. Default: None.
|
||||
eval_indexes (list): In case of defining the `eval_network`, if `eval_indexes` is None, all outputs of
|
||||
`eval_network` would be passed to metrics, otherwise `eval_indexes` must contain three
|
||||
elements, representing the positions of loss value, predict value and label, the loss
|
||||
value would be passed to `Loss` metric, predict value and label would be passed to other
|
||||
metric. Default: None.
|
||||
amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed
|
||||
precision training. Supports [O0, O2, O3]. Default: "O0".
|
||||
|
||||
- O0: Do not change.
|
||||
- O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
|
||||
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
|
||||
|
||||
O2 is recommended on GPU, O3 is recommended on Ascend.
|
||||
|
||||
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
|
||||
scale the loss by LossScaleManager. If it is set, overwrite the level setting. It's a eyword argument.
|
||||
e.g. Use `loss_scale_manager=None` to set the value.
|
||||
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. Default: True.
|
||||
|
||||
Examples:
|
||||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
>>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
|
||||
>>> self.bn = nn.BatchNorm2d(64)
|
||||
>>> self.relu = nn.ReLU()
|
||||
>>> self.flatten = nn.Flatten()
|
||||
>>> self.fc = nn.Dense(64*224*224, 12) # padding=0
|
||||
>>>
|
||||
>>> def construct(self, x):
|
||||
>>> x = self.conv(x)
|
||||
>>> x = self.bn(x)
|
||||
>>> x = self.relu(x)
|
||||
>>> x = self.flatten(x)
|
||||
>>> out = self.fc(x)
|
||||
>>> return out
|
||||
>>>
|
||||
>>> net = Net()
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
|
||||
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
|
||||
>>> dataset = get_dataset()
|
||||
>>> model.train(2, dataset)
|
||||
"""
|
||||
|
||||
def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
|
||||
eval_indexes=None, amp_level="O0", frequency=278, stop_epoch=100, **kwargs):
|
||||
self._network = network
|
||||
self._loss_fn = loss_fn
|
||||
self._optimizer = optimizer
|
||||
self._loss_scale_manager = None
|
||||
self._loss_scale_manager_set = False
|
||||
self._keep_bn_fp32 = True
|
||||
self._check_kwargs(kwargs)
|
||||
self._amp_level = amp_level
|
||||
self._process_amp_args(kwargs)
|
||||
self._parallel_mode = _get_parallel_mode()
|
||||
self._device_number = _get_device_num()
|
||||
self._global_rank = _get_global_rank()
|
||||
self._parameter_broadcast = _get_parameter_broadcast()
|
||||
self._frequency = frequency
|
||||
self._stop_epoch = stop_epoch
|
||||
|
||||
self._train_network = self._build_train_network()
|
||||
self._build_eval_network(metrics, eval_network, eval_indexes)
|
||||
self._build_predict_network()
|
||||
|
||||
def _process_amp_args(self, kwargs):
|
||||
if self._amp_level in ["O0", "O3"]:
|
||||
self._keep_bn_fp32 = False
|
||||
if 'keep_batchnorm_fp32' in kwargs:
|
||||
self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
|
||||
if 'loss_scale_manager' in kwargs:
|
||||
self._loss_scale_manager = kwargs['loss_scale_manager']
|
||||
self._loss_scale_manager_set = True
|
||||
|
||||
def _check_kwargs(self, kwargs):
|
||||
for arg in kwargs:
|
||||
if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
|
||||
raise ValueError(f"Unsupport arg '{arg}'")
|
||||
|
||||
def _build_train_network(self):
|
||||
"""Build train network"""
|
||||
network = self._network
|
||||
if self._optimizer:
|
||||
if self._loss_scale_manager_set:
|
||||
network = amp.build_train_network(network,
|
||||
self._optimizer,
|
||||
self._loss_fn,
|
||||
level=self._amp_level,
|
||||
loss_scale_manager=self._loss_scale_manager,
|
||||
keep_batchnorm_fp32=self._keep_bn_fp32)
|
||||
else:
|
||||
network = amp.build_train_network(network,
|
||||
self._optimizer,
|
||||
self._loss_fn,
|
||||
level=self._amp_level,
|
||||
keep_batchnorm_fp32=self._keep_bn_fp32)
|
||||
elif self._loss_fn:
|
||||
network = nn.WithLossCell(network, self._loss_fn)
|
||||
# If need to check if loss_fn is not None, but optimizer is None
|
||||
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
network.set_auto_parallel()
|
||||
return network
|
||||
|
||||
def _build_eval_network(self, metrics, eval_network, eval_indexes):
|
||||
"""Build the network for evaluation."""
|
||||
self._metric_fns = get_metrics(metrics)
|
||||
if not self._metric_fns:
|
||||
return
|
||||
|
||||
if eval_network is not None:
|
||||
if eval_indexes is not None and not (isinstance(eval_indexes, list) and len(eval_indexes) == 3):
|
||||
raise ValueError("Eval_indexes must be a list or None. If eval_indexes is a list, length of it \
|
||||
must be three. But got {}".format(eval_indexes))
|
||||
|
||||
self._eval_network = eval_network
|
||||
self._eval_indexes = eval_indexes
|
||||
else:
|
||||
if self._loss_fn is None:
|
||||
raise ValueError("loss_fn can not be None.")
|
||||
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2")
|
||||
self._eval_indexes = [0, 1, 2]
|
||||
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
if self._optimizer:
|
||||
self._eval_network = _VirtualDatasetCell(self._eval_network)
|
||||
self._eval_network.set_auto_parallel()
|
||||
|
||||
def _build_predict_network(self):
|
||||
"""Build the network for prediction."""
|
||||
self._predict_network = self._network
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
self._predict_network = _VirtualDatasetCell(self._network)
|
||||
self._predict_network.set_auto_parallel()
|
||||
|
||||
def _clear_metrics(self):
|
||||
"""Clear metrics local values."""
|
||||
for metric in self._metric_fns.values():
|
||||
metric.clear()
|
||||
|
||||
def _update_metrics(self, outputs):
|
||||
"""Update metrics local values."""
|
||||
if not isinstance(outputs, tuple):
|
||||
raise ValueError("The `outputs` is not tuple.")
|
||||
|
||||
if self._eval_indexes is not None and len(outputs) < 3:
|
||||
raise ValueError("The length of `outputs` must be greater than or equal to 3, \
|
||||
but got {}".format(len(outputs)))
|
||||
|
||||
for metric in self._metric_fns.values():
|
||||
if self._eval_indexes is None:
|
||||
metric.update(*outputs)
|
||||
else:
|
||||
if isinstance(metric, Loss):
|
||||
metric.update(outputs[self._eval_indexes[0]])
|
||||
else:
|
||||
metric.update(outputs[self._eval_indexes[1]], outputs[self._eval_indexes[2]])
|
||||
|
||||
def _get_metrics(self):
|
||||
"""Get metrics local values."""
|
||||
metrics = dict()
|
||||
for key, value in self._metric_fns.items():
|
||||
metrics[key] = value.eval()
|
||||
return metrics
|
||||
|
||||
def _get_scaling_sens(self):
|
||||
"""get the scaling sens"""
|
||||
scaling_sens = 1
|
||||
if self._loss_scale_manager is not None:
|
||||
scaling_sens = self._loss_scale_manager.get_loss_scale()
|
||||
if self._parallel_mode == ParallelMode.DATA_PARALLEL:
|
||||
scaling_sens /= self._device_number
|
||||
return scaling_sens
|
||||
|
||||
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1,
|
||||
iter_first_order=9):
|
||||
"""Initializes dataset."""
|
||||
need_wrap = False
|
||||
if dataset_sink_mode:
|
||||
# remove later to deal with loop sink
|
||||
if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
|
||||
and not context.get_context("enable_ge"):
|
||||
need_wrap = True
|
||||
|
||||
if not is_train:
|
||||
dataset.__loop_size__ = 1
|
||||
|
||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)
|
||||
|
||||
# remove later to deal with loop sink
|
||||
if need_wrap:
|
||||
network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
|
||||
network.set_train(is_train)
|
||||
network.phase = phase
|
||||
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
network.set_auto_parallel()
|
||||
|
||||
return dataset_helper, network
|
||||
|
||||
def init(self, train_dataset=None, valid_dataset=None):
|
||||
"""
|
||||
Initializes compute graphs and data graphs with sink mode.
|
||||
|
||||
Note:
|
||||
Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently.
|
||||
|
||||
Args:
|
||||
train_dataset (Dataset): A training dataset iterator. If define `train_dataset`, training graphs will be
|
||||
initialized. Default: None.
|
||||
valid_dataset (Dataset): A evaluating dataset iterator. If define `valid_dataset`, evaluation graphs will
|
||||
be initialized, and `metrics` in `Model` can not be None. Default: None.
|
||||
|
||||
Examples:
|
||||
>>> train_dataset = get_train_dataset()
|
||||
>>> valid_dataset = get_valid_dataset()
|
||||
>>> net = Net()
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
|
||||
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'})
|
||||
>>> model.init(train_dataset, valid_dataset)
|
||||
>>> model.train(2, train_dataset)
|
||||
>>> model.eval(valid_dataset)
|
||||
"""
|
||||
if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend":
|
||||
raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')
|
||||
|
||||
if not train_dataset and not valid_dataset:
|
||||
raise ValueError('Both train_dataset and valid_dataset can not be None or empty.')
|
||||
|
||||
_device_number_check(self._parallel_mode, self._device_number)
|
||||
|
||||
if train_dataset:
|
||||
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
|
||||
self._train_network.set_train()
|
||||
self._train_network.phase = 'train'
|
||||
|
||||
if self._parameter_broadcast:
|
||||
self._train_network.set_broadcast_flag()
|
||||
train_dataset.__no_send__ = True
|
||||
train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
|
||||
is_train=True,
|
||||
phase='train',
|
||||
dataset=train_dataset,
|
||||
dataset_sink_mode=True)
|
||||
self._train_network = train_network
|
||||
for inputs in train_dataset_helper:
|
||||
self._train_network.compile(*inputs)
|
||||
break
|
||||
|
||||
if valid_dataset:
|
||||
if not self._metric_fns:
|
||||
raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.')
|
||||
|
||||
self._eval_network.set_train(False)
|
||||
self._eval_network.phase = 'eval'
|
||||
valid_dataset.__no_send__ = True
|
||||
valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
|
||||
is_train=False,
|
||||
phase='eval',
|
||||
dataset=valid_dataset,
|
||||
dataset_sink_mode=True)
|
||||
self._eval_network = eval_network
|
||||
for inputs in valid_dataset_helper:
|
||||
self._eval_network.compile(*inputs)
|
||||
break
|
||||
|
||||
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
|
||||
"""
|
||||
Training.
|
||||
|
||||
Args:
|
||||
epoch (int): Total number of iterations on the data.
|
||||
train_dataset (Dataset): A training dataset iterator. If there is no
|
||||
loss_fn, a tuple with multiply data (data1, data2, data3, ...) will be
|
||||
returned and passed to the network. Otherwise, a tuple (data, label) will
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
callbacks (list): List of callback object. Callbacks which should be executed while training. Default: None.
|
||||
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
|
||||
Configure pynative mode, the training process will be performed with
|
||||
dataset not sink.
|
||||
sink_size (int): Control the amount of data each sink. Default: -1.
|
||||
"""
|
||||
epoch = check_int_positive(epoch)
|
||||
self._train_network.set_train()
|
||||
|
||||
if self._parameter_broadcast:
|
||||
self._train_network.set_broadcast_flag()
|
||||
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.train_network = self._train_network
|
||||
cb_params.epoch_num = epoch
|
||||
if dataset_sink_mode and sink_size > 0:
|
||||
cb_params.batch_num = sink_size
|
||||
else:
|
||||
cb_params.batch_num = train_dataset.get_dataset_size()
|
||||
cb_params.mode = "train"
|
||||
cb_params.loss_fn = self._loss_fn
|
||||
cb_params.optimizer = self._optimizer
|
||||
cb_params.parallel_mode = self._parallel_mode
|
||||
cb_params.device_number = self._device_number
|
||||
cb_params.train_dataset = train_dataset
|
||||
cb_params.list_callback = self._transform_callbacks(callbacks)
|
||||
cb_params.train_dataset_element = None
|
||||
cb_params.network = self._network
|
||||
ms_role = os.getenv("MS_ROLE")
|
||||
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
||||
epoch = 1
|
||||
|
||||
# build callback list
|
||||
with _CallbackManager(callbacks) as list_callback:
|
||||
if not dataset_sink_mode:
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
elif context.get_context("mode") == context.PYNATIVE_MODE:
|
||||
logger.warning("The pynative mode cannot support dataset sink mode currently."
|
||||
"So the training process will be performed with dataset not sink.")
|
||||
self._train_process(epoch, train_dataset, list_callback, cb_params)
|
||||
else:
|
||||
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size)
|
||||
|
||||
@staticmethod
|
||||
def _transform_callbacks(callbacks):
|
||||
"""Transform callback to a list."""
|
||||
if callbacks is None:
|
||||
return []
|
||||
|
||||
if isinstance(callbacks, Iterable):
|
||||
return list(callbacks)
|
||||
|
||||
return [callbacks]
|
||||
|
||||
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
|
||||
"""
|
||||
Training process. The data would be passed to network through dataset channel.
|
||||
|
||||
Args:
|
||||
epoch (int): Total number of iterations on the data.
|
||||
train_dataset (Dataset): A training dataset iterator. If there is no
|
||||
loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be
|
||||
returned and passed to the network. Otherwise, a tuple (data, label) should
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
sink_size (int): Control the amount of data each sink. Default: -1.
|
||||
"""
|
||||
if sink_size == -1:
|
||||
epoch_num = epoch
|
||||
else:
|
||||
epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
|
||||
|
||||
iter_first_order = self._frequency - 1
|
||||
iter_second_order = 1
|
||||
train_dataset.__loop_size__ = iter_second_order
|
||||
dataset_helper, train_network = self._exec_preprocess(self._train_network,
|
||||
is_train=True,
|
||||
phase='train',
|
||||
dataset=train_dataset,
|
||||
dataset_sink_mode=True,
|
||||
sink_size=sink_size,
|
||||
epoch_num=epoch_num,
|
||||
iter_first_order=iter_first_order)
|
||||
self._train_network = train_network
|
||||
cb_params.train_network = self._train_network
|
||||
cb_params.cur_step_num = 0
|
||||
|
||||
run_context = RunContext(cb_params)
|
||||
list_callback.begin(run_context)
|
||||
|
||||
# used to stop training for early stop, such as stopAtTIme or stopATStep
|
||||
should_stop = False
|
||||
has_do_dataset_init = False
|
||||
switch_branch_one = True
|
||||
train_network_init_flag = True
|
||||
for i in range(epoch):
|
||||
cb_params.cur_epoch_num = i + 1
|
||||
list_callback.epoch_begin(run_context)
|
||||
|
||||
# for data sink dataset_helper only iter once, other wise iter epoch_size times.
|
||||
for inputs in dataset_helper:
|
||||
if _need_to_full():
|
||||
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
|
||||
list_callback.step_begin(run_context)
|
||||
if switch_branch_one:
|
||||
cb_params.cur_step_num += dataset_helper.sink_size()
|
||||
if train_network_init_flag:
|
||||
self._train_network.add_flags_recursive(thor=True)
|
||||
self._train_network.phase = 'train0'
|
||||
else:
|
||||
cb_params.cur_step_num += iter_first_order
|
||||
if train_network_init_flag:
|
||||
self._train_network.add_flags_recursive(thor=False)
|
||||
train_network_init_flag = False
|
||||
self._train_network.phase = 'train1'
|
||||
if not has_do_dataset_init:
|
||||
_exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
|
||||
has_do_dataset_init = True
|
||||
switch_branch_one = not switch_branch_one
|
||||
outputs = self._train_network(*inputs)
|
||||
cb_params.net_outputs = outputs
|
||||
list_callback.step_end(run_context)
|
||||
|
||||
list_callback.epoch_end(run_context)
|
||||
should_stop = should_stop or run_context.get_stop_requested()
|
||||
if should_stop:
|
||||
break
|
||||
dataset_helper.stop_send()
|
||||
|
||||
list_callback.end(run_context)
|
||||
|
||||
def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
|
||||
"""
|
||||
Training process. The data would be passed to network directly.
|
||||
|
||||
Args:
|
||||
epoch (int): Total number of iterations on the data.
|
||||
train_dataset (Dataset): A training dataset iterator. If there is no
|
||||
loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be
|
||||
returned and passed to the network. Otherwise, a tuple (data, label) should
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
"""
|
||||
dataset_helper, _ = self._exec_preprocess(self._train_network,
|
||||
is_train=True,
|
||||
phase='train',
|
||||
dataset=train_dataset,
|
||||
dataset_sink_mode=False)
|
||||
cb_params.cur_step_num = 0
|
||||
run_context = RunContext(cb_params)
|
||||
list_callback.begin(run_context)
|
||||
# used to stop training for early stop, such as stopAtTIme or stopATStep
|
||||
should_stop = False
|
||||
|
||||
for i in range(epoch):
|
||||
cb_params.cur_epoch_num = i + 1
|
||||
|
||||
list_callback.epoch_begin(run_context)
|
||||
|
||||
for next_element in dataset_helper:
|
||||
len_element = len(next_element)
|
||||
if self._loss_fn and len_element != 2:
|
||||
raise ValueError("when loss_fn is not None, train_dataset should"
|
||||
"return two elements, but got {}".format(len_element))
|
||||
cb_params.cur_step_num += 1
|
||||
list_callback.step_begin(run_context)
|
||||
|
||||
overflow = False
|
||||
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
|
||||
scaling_sens = self._get_scaling_sens()
|
||||
next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
|
||||
|
||||
cb_params.train_dataset_element = next_element
|
||||
outputs = self._train_network(*next_element)
|
||||
cb_params.net_outputs = outputs
|
||||
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
|
||||
_, overflow, _ = outputs
|
||||
overflow = np.all(overflow.asnumpy())
|
||||
self._loss_scale_manager.update_loss_scale(overflow)
|
||||
|
||||
list_callback.step_end(run_context)
|
||||
should_stop = should_stop or run_context.get_stop_requested()
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
train_dataset.reset()
|
||||
|
||||
list_callback.epoch_end(run_context)
|
||||
should_stop = should_stop or run_context.get_stop_requested()
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
list_callback.end(run_context)
|
||||
|
||||
def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
|
||||
"""
|
||||
Training API where the iteration is controlled by python front-end.
|
||||
|
||||
When setting pynative mode, the training process will be performed with dataset not sink.
|
||||
|
||||
Note:
|
||||
CPU is not supported when dataset_sink_mode is true.
|
||||
If dataset_sink_mode is True, epoch of training should be equal to the count of repeat
|
||||
operation in dataset processing. Otherwise, errors could occur since the amount of data
|
||||
is not the amount training requires.
|
||||
If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features
|
||||
of data will be transferred one by one. The limitation of data transmission per time is 256M.
|
||||
|
||||
Args:
|
||||
epoch (int): Total number of iterations on the data.
|
||||
train_dataset (Dataset): A training dataset iterator. If there is no
|
||||
loss_fn, a tuple with multiply data (data1, data2, data3, ...) should be
|
||||
returned and passed to the network. Otherwise, a tuple (data, label) should
|
||||
be returned, and the data and label are passed to the network and loss
|
||||
function respectively.
|
||||
callbacks (list): List of callback object. Callbacks which should be excuted while training. Default: None.
|
||||
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
|
||||
Configure pynative mode, the training process will be performed with
|
||||
dataset not sink.
|
||||
sink_size (int): Control the amount of data each sink.
|
||||
If sink_size=-1, sink the complete dataset each epoch.
|
||||
If sink_size>0, sink sink_size data each epoch.
|
||||
If dataset_sink_mode is False, set sink_size invalid. Default: -1.
|
||||
|
||||
Examples:
|
||||
>>> dataset = get_dataset()
|
||||
>>> net = Net()
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
|
||||
>>> loss_scale_manager = FixedLossScaleManager()
|
||||
>>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
|
||||
>>> model.train(2, dataset)
|
||||
"""
|
||||
check_bool(dataset_sink_mode)
|
||||
check_int(sink_size)
|
||||
if sink_size < -1 or sink_size == 0:
|
||||
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
|
||||
|
||||
_device_number_check(self._parallel_mode, self._device_number)
|
||||
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
|
||||
|
||||
self._train(epoch,
|
||||
train_dataset,
|
||||
callbacks=callbacks,
|
||||
dataset_sink_mode=dataset_sink_mode,
|
||||
sink_size=sink_size)
|
||||
|
||||
def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
|
||||
"""
|
||||
Evaluation. The data would be passed to network through dataset channel.
|
||||
|
||||
Args:
|
||||
valid_dataset (Dataset): Dataset to evaluate the model.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
|
||||
Returns:
|
||||
Dict, returns the loss value & metrics values for the model in test mode.
|
||||
"""
|
||||
run_context = RunContext(cb_params)
|
||||
|
||||
dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
|
||||
is_train=False,
|
||||
phase='eval',
|
||||
dataset=valid_dataset,
|
||||
dataset_sink_mode=True)
|
||||
self._eval_network = eval_network
|
||||
cb_params.eval_network = self._eval_network
|
||||
list_callback.begin(run_context)
|
||||
|
||||
for inputs in dataset_helper:
|
||||
cb_params.cur_step_num += 1
|
||||
list_callback.step_begin(run_context)
|
||||
|
||||
outputs = self._eval_network(*inputs)
|
||||
|
||||
cb_params.net_outputs = outputs
|
||||
list_callback.step_end(run_context)
|
||||
self._update_metrics(outputs)
|
||||
|
||||
metrics = self._get_metrics()
|
||||
cb_params.metrics = metrics
|
||||
list_callback.end(run_context)
|
||||
|
||||
return metrics
|
||||
|
||||
def _eval_process(self, valid_dataset, list_callback=None, cb_params=None):
|
||||
"""
|
||||
Evaluation. The data would be passed to network directly.
|
||||
|
||||
Args:
|
||||
valid_dataset (Dataset): Dataset to evaluate the model.
|
||||
list_callback (Callback): Executor of callback list. Default: None.
|
||||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
|
||||
Returns:
|
||||
Dict, returns the loss value & metrics values for the model in test mode.
|
||||
"""
|
||||
run_context = RunContext(cb_params)
|
||||
list_callback.begin(run_context)
|
||||
|
||||
dataset_helper, _ = self._exec_preprocess(self._eval_network,
|
||||
is_train=False,
|
||||
phase='eval',
|
||||
dataset=valid_dataset,
|
||||
dataset_sink_mode=False)
|
||||
for next_element in dataset_helper:
|
||||
cb_params.cur_step_num += 1
|
||||
list_callback.step_begin(run_context)
|
||||
outputs = self._eval_network(*next_element)
|
||||
cb_params.net_outputs = outputs
|
||||
list_callback.step_end(run_context)
|
||||
self._update_metrics(outputs)
|
||||
|
||||
valid_dataset.reset()
|
||||
|
||||
metrics = self._get_metrics()
|
||||
cb_params.metrics = metrics
|
||||
list_callback.end(run_context)
|
||||
return metrics
|
||||
|
||||
def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=True):
|
||||
"""
|
||||
Evaluation API where the iteration is controlled by python front-end.
|
||||
|
||||
Configure to pynative mode, the evaluation will be performed with dataset non-sink mode.
|
||||
|
||||
Note:
|
||||
CPU is not supported when dataset_sink_mode is true.
|
||||
If dataset_sink_mode is True, data will be sent to device. If device is Ascend, features
|
||||
of data will be transferred one by one. The limitation of data transmission per time is 256M.
|
||||
|
||||
Args:
|
||||
valid_dataset (Dataset): Dataset to evaluate the model.
|
||||
callbacks (list): List of callback object. Callbacks which should be excuted
|
||||
while training. Default: None.
|
||||
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
|
||||
|
||||
Returns:
|
||||
Dict, returns the loss value & metrics values for the model in test mode.
|
||||
|
||||
Examples:
|
||||
>>> dataset = get_dataset()
|
||||
>>> net = Net()
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
|
||||
>>> model.eval(dataset)
|
||||
"""
|
||||
check_bool(dataset_sink_mode)
|
||||
_device_number_check(self._parallel_mode, self._device_number)
|
||||
if not self._metric_fns:
|
||||
raise ValueError("metric fn can not be None or empty.")
|
||||
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.eval_network = self._eval_network
|
||||
cb_params.valid_dataset = valid_dataset
|
||||
cb_params.batch_num = valid_dataset.get_dataset_size()
|
||||
cb_params.mode = "eval"
|
||||
cb_params.cur_step_num = 0
|
||||
cb_params.list_callback = self._transform_callbacks(callbacks)
|
||||
cb_params.network = self._network
|
||||
|
||||
self._eval_network.set_train(mode=False)
|
||||
self._eval_network.phase = 'eval'
|
||||
|
||||
self._clear_metrics()
|
||||
|
||||
with _CallbackManager(callbacks) as list_callback:
|
||||
if dataset_sink_mode:
|
||||
return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
|
||||
return self._eval_process(valid_dataset, list_callback, cb_params)
|
||||
|
||||
def predict(self, *predict_data):
|
||||
"""
|
||||
Generates output predictions for the input samples.
|
||||
|
||||
Data could be single tensor, or list of tensor, tuple of tensor.
|
||||
|
||||
Note:
|
||||
Batch data should be put together in one tensor.
|
||||
|
||||
Args:
|
||||
predict_data (Tensor): Tensor of predict data. can be array, list or tuple.
|
||||
|
||||
Returns:
|
||||
Tensor, array(s) of predictions.
|
||||
|
||||
Examples:
|
||||
>>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
|
||||
>>> model = Model(Net())
|
||||
>>> model.predict(input_data)
|
||||
"""
|
||||
self._predict_network.set_train(False)
|
||||
check_input_data(*predict_data, data_class=Tensor)
|
||||
result = self._predict_network(*predict_data)
|
||||
|
||||
check_output_data(result)
|
||||
return result
|
||||
|
||||
|
||||
__all__ = ["Model"]
|
|
@ -0,0 +1,422 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""momentum"""
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
|
||||
momentum_opt = C.MultitypeFuncGraph("momentum_opt")
|
||||
|
||||
|
||||
@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
|
||||
def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
|
||||
"""Apply momentum optimizer to the weight parameter using Tensor."""
|
||||
success = True
|
||||
success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
|
||||
return success
|
||||
|
||||
|
||||
op_add = P.AddN()
|
||||
apply_decay = C.MultitypeFuncGraph("apply_decay")
|
||||
|
||||
|
||||
@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
|
||||
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
|
||||
"""Get grad with weight_decay."""
|
||||
if if_apply:
|
||||
return op_add((weight * weight_decay, gradient))
|
||||
return gradient
|
||||
|
||||
|
||||
class THOR(Optimizer):
|
||||
"""THOR"""
|
||||
|
||||
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0,
|
||||
loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03, frequency=10,
|
||||
decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()):
|
||||
super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale)
|
||||
if isinstance(momentum, float) and momentum < 0.0:
|
||||
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
|
||||
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
|
||||
self.params = self.parameters
|
||||
self.moments = self.params.clone(prefix="moments", init='zeros')
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.ApplyMomentum()
|
||||
self.matrix_A = ParameterTuple(matrix_A)
|
||||
self.matrix_G = ParameterTuple(matrix_G)
|
||||
self.A_inv_max = ParameterTuple(A_inv_max)
|
||||
self.G_inv_max = ParameterTuple(G_inv_max)
|
||||
self.matmul = P.MatMul()
|
||||
self.transpose = P.Transpose()
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.mul = P.Mul()
|
||||
self.gather = P.GatherV2()
|
||||
self.matrix_A_inv = ()
|
||||
self.matrix_G_inv = ()
|
||||
self.matrix_max_inv = ()
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
fc_layer_num = num_hidden_layers * 6 + 5
|
||||
for i in range(fc_layer_num):
|
||||
self.matrix_max_inv = self.matrix_max_inv + (
|
||||
Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
|
||||
self.log = P.Log()
|
||||
self.exp = P.Exp()
|
||||
self.sqrt = P.Sqrt()
|
||||
self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
|
||||
self.assign = P.Assign()
|
||||
self.cast = P.Cast()
|
||||
self.thor = True
|
||||
self.weight_decay = weight_decay * loss_scale
|
||||
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
|
||||
self.expand = P.ExpandDims()
|
||||
self.square = P.Square()
|
||||
self.inv = P.Inv()
|
||||
self.batch_size = batch_size
|
||||
self.damping = damping
|
||||
self.freq = Tensor(frequency, mstype.int32)
|
||||
self.one = Tensor(1, mstype.int32)
|
||||
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
|
||||
|
||||
def construct(self, gradients):
|
||||
"""construct of THOR"""
|
||||
params = self.params
|
||||
moments = self.moments
|
||||
encoder_layers_num = 16
|
||||
if self.thor:
|
||||
new_grads = ()
|
||||
# process embedding layer
|
||||
for em_idx in range(3):
|
||||
g = gradients[em_idx]
|
||||
matrix_idx = em_idx
|
||||
temp_a_ori = self.matrix_A[matrix_idx]
|
||||
temp_a = self.expand(temp_a_ori, 1)
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
G_max = self.G_inv_max[matrix_idx]
|
||||
temp_g = self.cast(temp_g, mstype.float32)
|
||||
matrix_G_inv_max = self.log(G_max)
|
||||
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
|
||||
matrix_G_inv_max = self.exp(matrix_G_inv_max)
|
||||
temp_g = self.mul(temp_g, matrix_G_inv_max)
|
||||
g = self.mul(temp_a, g)
|
||||
g = self.cast(g, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.matmul(g, temp_g)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, G_max)
|
||||
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a_ori)
|
||||
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
|
||||
fake_max = self.assign(self.matrix_max_inv[matrix_idx], G_max)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
g = F.depend(g, fake_max)
|
||||
new_grads = new_grads + (g,)
|
||||
# process bert_embedding_postprocessor.layernorm
|
||||
grad_idx = 3
|
||||
beta_grad = gradients[grad_idx]
|
||||
gamma_grad = gradients[grad_idx + 1]
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
damping_step = self.gather(self.damping, self.cov_step, 0)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
self.cov_step = self.cov_step + self.one
|
||||
damping = self.sqrt(damping_step)
|
||||
beta = self.square(beta_grad)
|
||||
beta_cov = self.mul(beta, 1.0 / normalizer)
|
||||
beta_cov = beta_cov + damping
|
||||
beta_inv = self.inv(beta_cov)
|
||||
gamma = self.square(gamma_grad)
|
||||
gamma_cov = self.mul(gamma, 1.0 / normalizer)
|
||||
gamma_cov = gamma_cov + damping
|
||||
gamma_inv = self.inv(gamma_cov)
|
||||
beta = self.mul(beta_inv, beta_grad)
|
||||
gamma = self.mul(gamma_inv, gamma_grad)
|
||||
new_grads = new_grads + (beta, gamma)
|
||||
|
||||
for i in range(self.num_hidden_layers):
|
||||
encoder_begin_idx = encoder_layers_num * i + 5
|
||||
for j in range(0, encoder_layers_num, 2):
|
||||
grad_idx = encoder_begin_idx + j
|
||||
if j in (8, 14):
|
||||
# process layernorm layer
|
||||
beta_grad = gradients[grad_idx]
|
||||
gamma_grad = gradients[grad_idx + 1]
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
beta = self.square(beta_grad)
|
||||
beta_cov = self.mul(beta, 1.0 / normalizer)
|
||||
beta_cov = beta_cov + damping
|
||||
beta_inv = self.inv(beta_cov)
|
||||
gamma = self.square(gamma_grad)
|
||||
gamma_cov = self.mul(gamma, 1.0 / normalizer)
|
||||
gamma_cov = gamma_cov + damping
|
||||
gamma_inv = self.inv(gamma_cov)
|
||||
beta = self.mul(beta_inv, beta_grad)
|
||||
gamma = self.mul(gamma_inv, gamma_grad)
|
||||
new_grads = new_grads + (beta, gamma)
|
||||
else:
|
||||
g = gradients[grad_idx]
|
||||
offset_idx = 0
|
||||
if j in (0, 2, 4, 6):
|
||||
offset_idx = j // 2
|
||||
elif j in (10, 12):
|
||||
offset_idx = j // 2 - 1
|
||||
matrix_idx = 6 * i + offset_idx + 3
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float32)
|
||||
temp_g = self.cast(temp_g, mstype.float32)
|
||||
matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx])
|
||||
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
|
||||
matrix_A_inv_max = self.exp(matrix_A_inv_max)
|
||||
temp_a = self.mul(temp_a, matrix_A_inv_max)
|
||||
matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx])
|
||||
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
|
||||
matrix_G_inv_max = self.exp(matrix_G_inv_max)
|
||||
temp_g = self.mul(temp_g, matrix_G_inv_max)
|
||||
temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx])
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, temp_max)
|
||||
|
||||
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
|
||||
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
|
||||
fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
g = F.depend(g, fake_max)
|
||||
new_grads = new_grads + (g,)
|
||||
new_grads = new_grads + (gradients[grad_idx + 1],)
|
||||
|
||||
# process pooler layer
|
||||
pooler_layer_idx = encoder_layers_num * self.num_hidden_layers + 5
|
||||
matrix_idx = self.num_hidden_layers * 6 + 3
|
||||
g = gradients[pooler_layer_idx]
|
||||
pooler_bias = gradients[pooler_layer_idx + 1]
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float32)
|
||||
temp_g = self.cast(temp_g, mstype.float32)
|
||||
matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx])
|
||||
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
|
||||
matrix_A_inv_max = self.exp(matrix_A_inv_max)
|
||||
temp_a = self.mul(temp_a, matrix_A_inv_max)
|
||||
matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx])
|
||||
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
|
||||
matrix_G_inv_max = self.exp(matrix_G_inv_max)
|
||||
temp_g = self.mul(temp_g, matrix_G_inv_max)
|
||||
temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx])
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, temp_max)
|
||||
|
||||
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
|
||||
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
|
||||
fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
g = F.depend(g, fake_max)
|
||||
new_grads = new_grads + (g, pooler_bias)
|
||||
|
||||
# for cls1 fc layer: mlm
|
||||
mlm_fc_idx = encoder_layers_num * self.num_hidden_layers + 8
|
||||
matrix_idx = self.num_hidden_layers * 6 + 4
|
||||
g = gradients[mlm_fc_idx]
|
||||
mlm_bias = gradients[mlm_fc_idx + 1]
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float32)
|
||||
temp_g = self.cast(temp_g, mstype.float32)
|
||||
matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx])
|
||||
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
|
||||
matrix_A_inv_max = self.exp(matrix_A_inv_max)
|
||||
temp_a = self.mul(temp_a, matrix_A_inv_max)
|
||||
matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx])
|
||||
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
|
||||
matrix_G_inv_max = self.exp(matrix_G_inv_max)
|
||||
temp_g = self.mul(temp_g, matrix_G_inv_max)
|
||||
temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx])
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, temp_max)
|
||||
|
||||
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
|
||||
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
|
||||
fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
g = F.depend(g, fake_max)
|
||||
new_grads = new_grads + (gradients[mlm_fc_idx - 1],)
|
||||
new_grads = new_grads + (g, mlm_bias)
|
||||
# add bert.cls1.layernorm grad
|
||||
begin_idx = mlm_fc_idx + 2
|
||||
end_idx = mlm_fc_idx + 4
|
||||
new_grads = new_grads + gradients[begin_idx: end_idx]
|
||||
lenth = len(gradients)
|
||||
new_grads = new_grads + gradients[lenth - 2: lenth]
|
||||
gradients = new_grads
|
||||
else:
|
||||
new_grads = ()
|
||||
# process embedding layer
|
||||
for em_idx in range(3):
|
||||
g = gradients[em_idx]
|
||||
matrix_idx = em_idx
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_a = self.expand(temp_a, 1)
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
matrix_max = self.matrix_max_inv[matrix_idx]
|
||||
g = self.mul(temp_a, g)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(g, temp_g)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, matrix_max)
|
||||
new_grads = new_grads + (g,)
|
||||
# process bert_embedding_postprocessor.layernorm
|
||||
grad_idx = 3
|
||||
beta_grad = gradients[grad_idx]
|
||||
gamma_grad = gradients[grad_idx + 1]
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
damping_step = self.gather(self.damping, self.cov_step, 0)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
self.cov_step = self.cov_step + self.one
|
||||
damping = self.sqrt(damping_step)
|
||||
beta = self.square(beta_grad)
|
||||
beta_cov = self.mul(beta, 1.0 / normalizer)
|
||||
beta_cov = beta_cov + damping
|
||||
beta_inv = self.inv(beta_cov)
|
||||
gamma = self.square(gamma_grad)
|
||||
gamma_cov = self.mul(gamma, 1.0 / normalizer)
|
||||
gamma_cov = gamma_cov + damping
|
||||
gamma_inv = self.inv(gamma_cov)
|
||||
beta = self.mul(beta_inv, beta_grad)
|
||||
gamma = self.mul(gamma_inv, gamma_grad)
|
||||
new_grads = new_grads + (beta, gamma)
|
||||
|
||||
for i in range(self.num_hidden_layers):
|
||||
encoder_begin_idx = encoder_layers_num * i + 5
|
||||
for j in range(0, encoder_layers_num, 2):
|
||||
grad_idx = encoder_begin_idx + j
|
||||
if j in (8, 14):
|
||||
# process layernorm layer
|
||||
beta_grad = gradients[grad_idx]
|
||||
gamma_grad = gradients[grad_idx + 1]
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
beta = self.square(beta_grad)
|
||||
beta_cov = self.mul(beta, 1.0 / normalizer)
|
||||
beta_cov = beta_cov + damping
|
||||
beta_inv = self.inv(beta_cov)
|
||||
gamma = self.square(gamma_grad)
|
||||
gamma_cov = self.mul(gamma, 1.0 / normalizer)
|
||||
gamma_cov = gamma_cov + damping
|
||||
gamma_inv = self.inv(gamma_cov)
|
||||
beta = self.mul(beta_inv, beta_grad)
|
||||
gamma = self.mul(gamma_inv, gamma_grad)
|
||||
new_grads = new_grads + (beta, gamma)
|
||||
else:
|
||||
g = gradients[grad_idx]
|
||||
offset_idx = 0
|
||||
if j in (0, 2, 4, 6):
|
||||
offset_idx = j // 2
|
||||
elif j in (10, 12):
|
||||
offset_idx = j // 2 - 1
|
||||
matrix_idx = 6 * i + offset_idx + 3
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
matrix_max = self.matrix_max_inv[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, matrix_max)
|
||||
new_grads = new_grads + (g,)
|
||||
new_grads = new_grads + (gradients[grad_idx + 1],)
|
||||
|
||||
# process pooler layer
|
||||
pooler_layer_idx = encoder_layers_num * self.num_hidden_layers + 5
|
||||
matrix_idx = self.num_hidden_layers * 6 + 3
|
||||
g = gradients[pooler_layer_idx]
|
||||
pooler_bias = gradients[pooler_layer_idx + 1]
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
matrix_max = self.matrix_max_inv[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, matrix_max)
|
||||
new_grads = new_grads + (g, pooler_bias)
|
||||
|
||||
# for cls1 fc layer: mlm
|
||||
mlm_fc_idx = encoder_layers_num * self.num_hidden_layers + 8
|
||||
matrix_idx = self.num_hidden_layers * 6 + 4
|
||||
g = gradients[mlm_fc_idx]
|
||||
mlm_bias = gradients[mlm_fc_idx + 1]
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
matrix_max = self.matrix_max_inv[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, matrix_max)
|
||||
# add bert.cls1.output_bias grad
|
||||
new_grads = new_grads + (gradients[mlm_fc_idx - 1],)
|
||||
new_grads = new_grads + (g, mlm_bias)
|
||||
# add bert.cls1.layernorm grad
|
||||
begin_idx = mlm_fc_idx + 2
|
||||
end_idx = mlm_fc_idx + 4
|
||||
new_grads = new_grads + gradients[begin_idx: end_idx]
|
||||
lenth = len(gradients)
|
||||
new_grads = new_grads + gradients[lenth - 2: lenth]
|
||||
gradients = new_grads
|
||||
|
||||
if self.weight_decay > 0:
|
||||
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
|
||||
params, gradients)
|
||||
gradients = self.scale_grad(gradients)
|
||||
lr = self.get_lr()
|
||||
success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
|
||||
return success
|
|
@ -0,0 +1,429 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""momentum"""
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
from mindspore.ops import functional as F, composite as C, operations as P
|
||||
from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
|
||||
from .grad_reducer_thor1 import DistributedGradReducerThor1
|
||||
|
||||
momentum_opt = C.MultitypeFuncGraph("momentum_opt")
|
||||
|
||||
|
||||
@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
|
||||
def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
|
||||
"""Apply momentum optimizer to the weight parameter using Tensor."""
|
||||
success = True
|
||||
success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
|
||||
return success
|
||||
|
||||
|
||||
op_add = P.AddN()
|
||||
apply_decay = C.MultitypeFuncGraph("apply_decay")
|
||||
|
||||
|
||||
@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
|
||||
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
|
||||
"""Get grad with weight_decay."""
|
||||
if if_apply:
|
||||
return op_add((weight * weight_decay, gradient))
|
||||
return gradient
|
||||
|
||||
|
||||
class THOR(Optimizer):
|
||||
"""THOR"""
|
||||
|
||||
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0,
|
||||
loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03, frequency=10,
|
||||
decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()):
|
||||
super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale)
|
||||
if isinstance(momentum, float) and momentum < 0.0:
|
||||
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
|
||||
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
|
||||
self.params = self.parameters
|
||||
self.moments = self.params.clone(prefix="moments", init='zeros')
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.ApplyMomentum()
|
||||
self.matrix_A = ParameterTuple(matrix_A)
|
||||
self.matrix_G = ParameterTuple(matrix_G)
|
||||
self.A_inv_max = ParameterTuple(A_inv_max)
|
||||
self.G_inv_max = ParameterTuple(G_inv_max)
|
||||
self.matmul = P.MatMul()
|
||||
self.transpose = P.Transpose()
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.mul = P.Mul()
|
||||
self.gather = P.GatherV2()
|
||||
self.matrix_A_inv = ()
|
||||
self.matrix_G_inv = ()
|
||||
self.matrix_max_inv = ()
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
fc_layer_num = num_hidden_layers * 6 + 5
|
||||
for i in range(fc_layer_num):
|
||||
self.matrix_max_inv = self.matrix_max_inv + (
|
||||
Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
|
||||
self.log = P.Log()
|
||||
self.exp = P.Exp()
|
||||
self.sqrt = P.Sqrt()
|
||||
self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
|
||||
self.assign = P.Assign()
|
||||
self.cast = P.Cast()
|
||||
self.thor = True
|
||||
self.weight_decay = weight_decay * loss_scale
|
||||
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
|
||||
self.expand = P.ExpandDims()
|
||||
self.square = P.Square()
|
||||
self.inv = P.Inv()
|
||||
self.batch_size = batch_size
|
||||
self.damping = damping
|
||||
self.freq = Tensor(frequency, mstype.int32)
|
||||
self.one = Tensor(1, mstype.int32)
|
||||
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
|
||||
mean = _get_mirror_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer_g = DistributedGradReducerThor1(self.parameters, 3, mean, degree)
|
||||
|
||||
def construct(self, gradients):
|
||||
"""construct of THOR"""
|
||||
params = self.params
|
||||
moments = self.moments
|
||||
encoder_layers_num = 16
|
||||
if self.thor:
|
||||
new_grads = ()
|
||||
# process embedding layer
|
||||
for em_idx in range(3):
|
||||
g = gradients[em_idx]
|
||||
matrix_idx = em_idx
|
||||
temp_a_ori = self.matrix_A[matrix_idx]
|
||||
temp_a = self.expand(temp_a_ori, 1)
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
G_max = self.G_inv_max[matrix_idx]
|
||||
temp_g = self.cast(temp_g, mstype.float32)
|
||||
matrix_G_inv_max = self.log(G_max)
|
||||
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
|
||||
matrix_G_inv_max = self.exp(matrix_G_inv_max)
|
||||
temp_g = self.mul(temp_g, matrix_G_inv_max)
|
||||
g = self.mul(temp_a, g)
|
||||
g = self.cast(g, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.matmul(g, temp_g)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, G_max)
|
||||
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a_ori)
|
||||
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
|
||||
fake_max = self.assign(self.matrix_max_inv[matrix_idx], G_max)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
g = F.depend(g, fake_max)
|
||||
new_grads = new_grads + (g,)
|
||||
# process bert_embedding_postprocessor.layernorm
|
||||
grad_idx = 3
|
||||
beta_grad = gradients[grad_idx]
|
||||
gamma_grad = gradients[grad_idx + 1]
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
damping_step = self.gather(self.damping, self.cov_step, 0)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
self.cov_step = self.cov_step + self.one
|
||||
damping = self.sqrt(damping_step)
|
||||
beta = self.square(beta_grad)
|
||||
beta_cov = self.mul(beta, 1.0 / normalizer)
|
||||
beta_cov = beta_cov + damping
|
||||
beta_inv = self.inv(beta_cov)
|
||||
gamma = self.square(gamma_grad)
|
||||
gamma_cov = self.mul(gamma, 1.0 / normalizer)
|
||||
gamma_cov = gamma_cov + damping
|
||||
gamma_inv = self.inv(gamma_cov)
|
||||
beta = self.mul(beta_inv, beta_grad)
|
||||
gamma = self.mul(gamma_inv, gamma_grad)
|
||||
new_grads = new_grads + (beta, gamma)
|
||||
|
||||
for i in range(self.num_hidden_layers):
|
||||
encoder_begin_idx = encoder_layers_num * i + 5
|
||||
for j in range(0, encoder_layers_num, 2):
|
||||
grad_idx = encoder_begin_idx + j
|
||||
if j in (8, 14):
|
||||
# process layernorm layer
|
||||
beta_grad = gradients[grad_idx]
|
||||
gamma_grad = gradients[grad_idx + 1]
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
beta = self.square(beta_grad)
|
||||
beta_cov = self.mul(beta, 1.0 / normalizer)
|
||||
beta_cov = beta_cov + damping
|
||||
beta_inv = self.inv(beta_cov)
|
||||
gamma = self.square(gamma_grad)
|
||||
gamma_cov = self.mul(gamma, 1.0 / normalizer)
|
||||
gamma_cov = gamma_cov + damping
|
||||
gamma_inv = self.inv(gamma_cov)
|
||||
beta = self.mul(beta_inv, beta_grad)
|
||||
gamma = self.mul(gamma_inv, gamma_grad)
|
||||
new_grads = new_grads + (beta, gamma)
|
||||
else:
|
||||
g = gradients[grad_idx]
|
||||
offset_idx = 0
|
||||
if j in (0, 2, 4, 6):
|
||||
offset_idx = j // 2
|
||||
elif j in (10, 12):
|
||||
offset_idx = j // 2 - 1
|
||||
matrix_idx = 6 * i + offset_idx + 3
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float32)
|
||||
temp_g = self.cast(temp_g, mstype.float32)
|
||||
matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx])
|
||||
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
|
||||
matrix_A_inv_max = self.exp(matrix_A_inv_max)
|
||||
temp_a = self.mul(temp_a, matrix_A_inv_max)
|
||||
matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx])
|
||||
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
|
||||
matrix_G_inv_max = self.exp(matrix_G_inv_max)
|
||||
temp_g = self.mul(temp_g, matrix_G_inv_max)
|
||||
temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx])
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, temp_max)
|
||||
|
||||
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
|
||||
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
|
||||
fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
g = F.depend(g, fake_max)
|
||||
new_grads = new_grads + (g,)
|
||||
new_grads = new_grads + (gradients[grad_idx + 1],)
|
||||
|
||||
# process pooler layer
|
||||
pooler_layer_idx = encoder_layers_num * self.num_hidden_layers + 5
|
||||
matrix_idx = self.num_hidden_layers * 6 + 3
|
||||
g = gradients[pooler_layer_idx]
|
||||
pooler_bias = gradients[pooler_layer_idx + 1]
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float32)
|
||||
temp_g = self.cast(temp_g, mstype.float32)
|
||||
matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx])
|
||||
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
|
||||
matrix_A_inv_max = self.exp(matrix_A_inv_max)
|
||||
temp_a = self.mul(temp_a, matrix_A_inv_max)
|
||||
matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx])
|
||||
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
|
||||
matrix_G_inv_max = self.exp(matrix_G_inv_max)
|
||||
temp_g = self.mul(temp_g, matrix_G_inv_max)
|
||||
temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx])
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, temp_max)
|
||||
|
||||
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
|
||||
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
|
||||
fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
g = F.depend(g, fake_max)
|
||||
new_grads = new_grads + (g, pooler_bias)
|
||||
|
||||
# for cls1 fc layer: mlm
|
||||
mlm_fc_idx = encoder_layers_num * self.num_hidden_layers + 8
|
||||
matrix_idx = self.num_hidden_layers * 6 + 4
|
||||
g = gradients[mlm_fc_idx]
|
||||
mlm_bias = gradients[mlm_fc_idx + 1]
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float32)
|
||||
temp_g = self.cast(temp_g, mstype.float32)
|
||||
matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx])
|
||||
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
|
||||
matrix_A_inv_max = self.exp(matrix_A_inv_max)
|
||||
temp_a = self.mul(temp_a, matrix_A_inv_max)
|
||||
matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx])
|
||||
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
|
||||
matrix_G_inv_max = self.exp(matrix_G_inv_max)
|
||||
temp_g = self.mul(temp_g, matrix_G_inv_max)
|
||||
temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx])
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, temp_max)
|
||||
|
||||
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
|
||||
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
|
||||
fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
g = F.depend(g, fake_max)
|
||||
new_grads = new_grads + (gradients[mlm_fc_idx - 1],)
|
||||
new_grads = new_grads + (g, mlm_bias)
|
||||
# add bert.cls1.layernorm grad
|
||||
begin_idx = mlm_fc_idx + 2
|
||||
end_idx = mlm_fc_idx + 4
|
||||
new_grads = new_grads + gradients[begin_idx: end_idx]
|
||||
lenth = len(gradients)
|
||||
new_grads = new_grads + gradients[lenth - 2: lenth]
|
||||
gradients = new_grads
|
||||
gradients = self.grad_reducer_g(gradients)
|
||||
else:
|
||||
new_grads = ()
|
||||
# process embedding layer
|
||||
for em_idx in range(3):
|
||||
g = gradients[em_idx]
|
||||
matrix_idx = em_idx
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_a = self.expand(temp_a, 1)
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
matrix_max = self.matrix_max_inv[matrix_idx]
|
||||
g = self.mul(temp_a, g)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(g, temp_g)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, matrix_max)
|
||||
new_grads = new_grads + (g,)
|
||||
# process bert_embedding_postprocessor.layernorm
|
||||
grad_idx = 3
|
||||
beta_grad = gradients[grad_idx]
|
||||
gamma_grad = gradients[grad_idx + 1]
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
damping_step = self.gather(self.damping, self.cov_step, 0)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
self.cov_step = self.cov_step + self.one
|
||||
damping = self.sqrt(damping_step)
|
||||
beta = self.square(beta_grad)
|
||||
beta_cov = self.mul(beta, 1.0 / normalizer)
|
||||
beta_cov = beta_cov + damping
|
||||
beta_inv = self.inv(beta_cov)
|
||||
gamma = self.square(gamma_grad)
|
||||
gamma_cov = self.mul(gamma, 1.0 / normalizer)
|
||||
gamma_cov = gamma_cov + damping
|
||||
gamma_inv = self.inv(gamma_cov)
|
||||
beta = self.mul(beta_inv, beta_grad)
|
||||
gamma = self.mul(gamma_inv, gamma_grad)
|
||||
new_grads = new_grads + (beta, gamma)
|
||||
|
||||
for i in range(self.num_hidden_layers):
|
||||
encoder_begin_idx = encoder_layers_num * i + 5
|
||||
for j in range(0, encoder_layers_num, 2):
|
||||
grad_idx = encoder_begin_idx + j
|
||||
if j in (8, 14):
|
||||
# process layernorm layer
|
||||
beta_grad = gradients[grad_idx]
|
||||
gamma_grad = gradients[grad_idx + 1]
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
beta = self.square(beta_grad)
|
||||
beta_cov = self.mul(beta, 1.0 / normalizer)
|
||||
beta_cov = beta_cov + damping
|
||||
beta_inv = self.inv(beta_cov)
|
||||
gamma = self.square(gamma_grad)
|
||||
gamma_cov = self.mul(gamma, 1.0 / normalizer)
|
||||
gamma_cov = gamma_cov + damping
|
||||
gamma_inv = self.inv(gamma_cov)
|
||||
beta = self.mul(beta_inv, beta_grad)
|
||||
gamma = self.mul(gamma_inv, gamma_grad)
|
||||
new_grads = new_grads + (beta, gamma)
|
||||
else:
|
||||
g = gradients[grad_idx]
|
||||
offset_idx = 0
|
||||
if j in (0, 2, 4, 6):
|
||||
offset_idx = j // 2
|
||||
elif j in (10, 12):
|
||||
offset_idx = j // 2 - 1
|
||||
matrix_idx = 6 * i + offset_idx + 3
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
matrix_max = self.matrix_max_inv[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, matrix_max)
|
||||
new_grads = new_grads + (g,)
|
||||
new_grads = new_grads + (gradients[grad_idx + 1],)
|
||||
|
||||
# process pooler layer
|
||||
pooler_layer_idx = encoder_layers_num * self.num_hidden_layers + 5
|
||||
matrix_idx = self.num_hidden_layers * 6 + 3
|
||||
g = gradients[pooler_layer_idx]
|
||||
pooler_bias = gradients[pooler_layer_idx + 1]
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
matrix_max = self.matrix_max_inv[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, matrix_max)
|
||||
new_grads = new_grads + (g, pooler_bias)
|
||||
|
||||
# for cls1 fc layer: mlm
|
||||
mlm_fc_idx = encoder_layers_num * self.num_hidden_layers + 8
|
||||
matrix_idx = self.num_hidden_layers * 6 + 4
|
||||
g = gradients[mlm_fc_idx]
|
||||
mlm_bias = gradients[mlm_fc_idx + 1]
|
||||
temp_a = self.matrix_A[matrix_idx]
|
||||
temp_g = self.matrix_G[matrix_idx]
|
||||
matrix_max = self.matrix_max_inv[matrix_idx]
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, matrix_max)
|
||||
# add bert.cls1.output_bias grad
|
||||
new_grads = new_grads + (gradients[mlm_fc_idx - 1],)
|
||||
new_grads = new_grads + (g, mlm_bias)
|
||||
# add bert.cls1.layernorm grad
|
||||
begin_idx = mlm_fc_idx + 2
|
||||
end_idx = mlm_fc_idx + 4
|
||||
new_grads = new_grads + gradients[begin_idx: end_idx]
|
||||
lenth = len(gradients)
|
||||
new_grads = new_grads + gradients[lenth - 2: lenth]
|
||||
gradients = new_grads
|
||||
gradients = self.grad_reducer_g(gradients)
|
||||
|
||||
if self.weight_decay > 0:
|
||||
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
|
||||
params, gradients)
|
||||
gradients = self.scale_grad(gradients)
|
||||
lr = self.get_lr()
|
||||
success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
|
||||
return success
|
|
@ -0,0 +1,304 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""thor_layer"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore._checkparam import check_bool, check_int_positive
|
||||
from mindspore.common.initializer import TruncatedNormal, initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Embedding_Thor(Cell):
|
||||
"""
|
||||
A embeddings lookup table with a fixed dictionary and size.
|
||||
|
||||
Args:
|
||||
vocab_size (int): Size of the dictionary of embeddings.
|
||||
embedding_size (int): The size of each embedding vector.
|
||||
embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
|
||||
each embedding vector.
|
||||
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
|
||||
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
embedding_size,
|
||||
embedding_shape,
|
||||
use_one_hot_embeddings=False,
|
||||
initializer_range=0.02,
|
||||
name='embedding_table',
|
||||
is_expand=False,
|
||||
batch_size=12,
|
||||
damping=0.03,
|
||||
loss_scale=1,
|
||||
frequency=10,
|
||||
):
|
||||
super(Embedding_Thor, self).__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.use_one_hot_embeddings = use_one_hot_embeddings
|
||||
self.embedding_table = Parameter(initializer
|
||||
(TruncatedNormal(initializer_range),
|
||||
[vocab_size, embedding_size]),
|
||||
name=name)
|
||||
self.thor = True
|
||||
self.is_expand = is_expand
|
||||
self.expand = P.ExpandDims()
|
||||
self.shape_flat = (-1,)
|
||||
self.gather = P.GatherV2()
|
||||
self.one_hot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.array_mul = P.MatMul()
|
||||
self.reshape = P.Reshape()
|
||||
self.em_shape = tuple(embedding_shape)
|
||||
self.shape = P.Shape()
|
||||
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
|
||||
self.matrix_A_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)), name='matrix_A_inv',
|
||||
requires_grad=False)
|
||||
self.matrix_G_inv = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float16)),
|
||||
name="matrix_G_inv", requires_grad=False)
|
||||
self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False)
|
||||
self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False)
|
||||
self.fused_abs_max = P.CusFusedAbsMax1()
|
||||
self.fake_G = Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float16))
|
||||
self.dampingA = Tensor(np.ones([vocab_size]).astype(np.float32))
|
||||
self.dampingG = Tensor(np.identity(embedding_size), mstype.float32)
|
||||
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
|
||||
self.freq = Tensor(frequency, mstype.int32)
|
||||
self.axis = 0
|
||||
self.damping = damping
|
||||
self.gather = P.GatherV2()
|
||||
self.sqrt = P.Sqrt()
|
||||
self.mul = P.Mul()
|
||||
self.cast = P.Cast()
|
||||
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
|
||||
self.vector_matmul = P.CusBatchMatMul()
|
||||
self.cholesky = P.CusCholeskyTrsm()
|
||||
self.matrix_combine = P.CusMatrixCombine()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.inv = P.Inv()
|
||||
self.getG = P.InsertGradientOf(self.save_gradient)
|
||||
self.batch_size = batch_size
|
||||
|
||||
def save_gradient(self, dout):
|
||||
"""save_gradient"""
|
||||
bs = self.batch_size
|
||||
bs = self.cast(bs, mstype.float32)
|
||||
out = dout
|
||||
dout = self.mul(dout, self.loss_scale)
|
||||
dout = self.mul(dout, bs)
|
||||
shape = self.shape(dout)
|
||||
normalizer = self.cast(shape[0], mstype.float32)
|
||||
matrix_G = self.cube_matmul(dout, dout)
|
||||
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
|
||||
damping_step = self.gather(self.damping, self.cov_step, 0)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
self.cov_step = self.cov_step + self.freq
|
||||
damping = self.sqrt(damping_step)
|
||||
dampingG = self.cast(self.dampingG, mstype.float32)
|
||||
matrix_G = matrix_G + damping * dampingG
|
||||
matrix_G_inv = self.cholesky(matrix_G)
|
||||
matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)
|
||||
matrix_G_inv_max = self.fused_abs_max(matrix_G_inv)
|
||||
matrix_G_inv_max = self.fused_abs_max(matrix_G_inv_max)
|
||||
self.G_inv_max = matrix_G_inv_max
|
||||
matrix_G_inv = self.matrix_combine(matrix_G_inv)
|
||||
matrix_G_inv = self.cast(matrix_G_inv, mstype.float16)
|
||||
self.matrix_G_inv = matrix_G_inv
|
||||
return out
|
||||
|
||||
def construct(self, input_ids):
|
||||
"""construct of Embedding_Thor"""
|
||||
if self.is_expand:
|
||||
input_ids = self.expand(input_ids, -1)
|
||||
flat_ids = self.reshape(input_ids, self.shape_flat)
|
||||
if self.use_one_hot_embeddings:
|
||||
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
|
||||
output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
|
||||
else:
|
||||
if self.thor:
|
||||
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
|
||||
matrix_A = self.reduce_sum(one_hot_ids, 0)
|
||||
normalizer = self.batch_size
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
matrix_A = self.mul(matrix_A, 1.0 / normalizer)
|
||||
damping_step = self.gather(self.damping, self.cov_step, self.axis)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
damping = self.sqrt(damping_step)
|
||||
dampingA = self.cast(self.dampingA, mstype.float32)
|
||||
matrix_A = matrix_A + damping * dampingA
|
||||
matrix_A_inv = self.inv(matrix_A)
|
||||
self.matrix_A_inv = matrix_A_inv
|
||||
self.matrix_G_inv = self.fake_G
|
||||
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
|
||||
output_for_reshape = self.getG(output_for_reshape)
|
||||
else:
|
||||
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
|
||||
|
||||
output = self.reshape(output_for_reshape, self.em_shape)
|
||||
return output, self.embedding_table
|
||||
|
||||
|
||||
class Dense_Thor(Cell):
|
||||
"""Dense_Thor"""
|
||||
|
||||
# @cell_attr_register(attrs=['has_bias', 'activation', 'in_channels', 'out_channels'])
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
damping=0.03,
|
||||
loss_scale=1,
|
||||
frequency=10,
|
||||
has_bias=False,
|
||||
activation=None,
|
||||
batch_size=12):
|
||||
super(Dense_Thor, self).__init__()
|
||||
self.in_channels = check_int_positive(in_channels)
|
||||
self.out_channels = check_int_positive(out_channels)
|
||||
self.has_bias = check_bool(has_bias)
|
||||
self.thor = True
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
|
||||
weight_init.shape()[1] != in_channels:
|
||||
raise ValueError("weight_init shape error")
|
||||
|
||||
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
|
||||
|
||||
if self.has_bias:
|
||||
if isinstance(bias_init, Tensor):
|
||||
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
|
||||
raise ValueError("bias_init shape error")
|
||||
|
||||
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
|
||||
|
||||
self.matmul = P.MatMul(transpose_b=True)
|
||||
self.bias_add = P.BiasAdd()
|
||||
|
||||
self.activation = get_activation(activation)
|
||||
self.activation_flag = self.activation is not None
|
||||
self.matrix_A_inv = Parameter(Tensor(np.zeros([in_channels, in_channels]).astype(np.float16)),
|
||||
name='matrix_A_inv', requires_grad=False)
|
||||
self.matrix_G_inv = Parameter(Tensor(np.zeros([out_channels, out_channels]).astype(np.float16)),
|
||||
name="matrix_G_inv", requires_grad=False)
|
||||
self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False)
|
||||
self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False)
|
||||
self.fused_abs_max = P.CusFusedAbsMax1()
|
||||
self.fake_G = Tensor(np.zeros([out_channels, out_channels]).astype(np.float16))
|
||||
|
||||
self.matmul = P.MatMul(transpose_b=True)
|
||||
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
|
||||
self.matrix_combine = P.CusMatrixCombine()
|
||||
self.cholesky = P.CusCholeskyTrsm()
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
|
||||
self.mul = P.Mul()
|
||||
self.cast = P.Cast()
|
||||
self.damping = damping
|
||||
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
|
||||
self.vector_matmul = P.CusBatchMatMul()
|
||||
self.gather = P.GatherV2()
|
||||
self.assignadd = P.AssignAdd()
|
||||
self.freq = Tensor(frequency, mstype.int32)
|
||||
self.axis = 0
|
||||
self.abs = P.Abs()
|
||||
self.reduce_max = P.ReduceMax(keep_dims=False)
|
||||
self.log = P.Log()
|
||||
self.exp = P.Exp()
|
||||
self.dampingA = Tensor(np.identity(in_channels), mstype.float32)
|
||||
self.dampingG = Tensor(np.identity(out_channels), mstype.float32)
|
||||
self.sqrt = P.Sqrt()
|
||||
self.getG = P.InsertGradientOf(self.save_gradient)
|
||||
self.batch_size = batch_size
|
||||
|
||||
def save_gradient(self, dout):
|
||||
"""save_gradient"""
|
||||
bs = self.cast(self.batch_size, mstype.float32)
|
||||
out = dout
|
||||
dout = self.mul(dout, self.loss_scale)
|
||||
dout = self.mul(dout, bs)
|
||||
shape = self.shape(dout)
|
||||
normalizer = self.cast(shape[0], mstype.float32)
|
||||
matrix_G = self.cube_matmul(dout, dout)
|
||||
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
|
||||
damping_step = self.gather(self.damping, self.cov_step, 0)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
self.cov_step = self.cov_step + self.freq
|
||||
damping = self.sqrt(damping_step)
|
||||
dampingG = self.cast(self.dampingG, mstype.float32)
|
||||
matrix_G = matrix_G + damping * dampingG
|
||||
matrix_G_inv = self.cholesky(matrix_G)
|
||||
matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)
|
||||
matrix_G_inv_max = self.fused_abs_max(matrix_G_inv)
|
||||
matrix_G_inv_max = self.fused_abs_max(matrix_G_inv_max)
|
||||
self.G_inv_max = matrix_G_inv_max
|
||||
matrix_G_inv = self.matrix_combine(matrix_G_inv)
|
||||
matrix_G_inv = self.cast(matrix_G_inv, mstype.float16)
|
||||
self.matrix_G_inv = matrix_G_inv
|
||||
return out
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
if self.thor:
|
||||
inputs = self.cube_matmul(x, x)
|
||||
shape = self.shape(x)
|
||||
normalizer = self.cast(shape[0], mstype.float32)
|
||||
matrix_A = self.mul(inputs, 1.0 / normalizer)
|
||||
|
||||
damping_step = self.gather(self.damping, self.cov_step, self.axis)
|
||||
damping_step = self.cast(damping_step, mstype.float32)
|
||||
damping = self.sqrt(damping_step)
|
||||
dampingA = self.cast(self.dampingA, mstype.float32)
|
||||
matrix_A = matrix_A + damping * dampingA
|
||||
matrix_A_inv = self.cholesky(matrix_A)
|
||||
matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv)
|
||||
matrix_A_inv_max = self.fused_abs_max(matrix_A_inv)
|
||||
matrix_A_inv_max = self.fused_abs_max(matrix_A_inv_max)
|
||||
self.A_inv_max = matrix_A_inv_max
|
||||
matrix_A_inv = self.matrix_combine(matrix_A_inv)
|
||||
matrix_A_inv = self.cast(matrix_A_inv, mstype.float16)
|
||||
self.matrix_A_inv = matrix_A_inv
|
||||
self.matrix_G_inv = self.fake_G
|
||||
output = self.matmul(x, self.weight)
|
||||
output = self.getG(output)
|
||||
else:
|
||||
output = self.matmul(x, self.weight)
|
||||
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
if self.activation_flag:
|
||||
return self.activation(output)
|
||||
return output
|
||||
|
||||
def extend_repr(self):
|
||||
"""extend_repr"""
|
||||
str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}' \
|
||||
.format(self.in_channels, self.out_channels, self.weight, self.has_bias)
|
||||
if self.has_bias:
|
||||
str_info = str_info + ', bias={}'.format(self.bias)
|
||||
|
||||
if self.activation_flag:
|
||||
str_info = str_info + ', activation={}'.format(self.activation)
|
||||
|
||||
return str_info
|
|
@ -0,0 +1,169 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
Functional Cells used in Bert finetune and evaluation.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from src.config import cfg
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
|
||||
class CrossEntropyCalculation(nn.Cell):
|
||||
"""
|
||||
Cross Entropy loss
|
||||
"""
|
||||
|
||||
def __init__(self, is_training=True):
|
||||
super(CrossEntropyCalculation, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
self.reshape = P.Reshape()
|
||||
self.last_idx = (-1,)
|
||||
self.neg = P.Neg()
|
||||
self.cast = P.Cast()
|
||||
self.is_training = is_training
|
||||
|
||||
def construct(self, logits, label_ids, num_labels):
|
||||
if self.is_training:
|
||||
label_ids = self.reshape(label_ids, self.last_idx)
|
||||
one_hot_labels = self.onehot(label_ids, num_labels, self.on_value, self.off_value)
|
||||
per_example_loss = self.neg(self.reduce_sum(one_hot_labels * logits, self.last_idx))
|
||||
loss = self.reduce_mean(per_example_loss, self.last_idx)
|
||||
return_value = self.cast(loss, mstype.float32)
|
||||
else:
|
||||
return_value = logits * 1.0
|
||||
return return_value
|
||||
|
||||
|
||||
def make_directory(path: str):
|
||||
"""Make directory."""
|
||||
if path is None or not isinstance(path, str) or path.strip() == "":
|
||||
logger.error("The path(%r) is invalid type.", path)
|
||||
raise TypeError("Input path is invaild type")
|
||||
|
||||
# convert the relative paths
|
||||
path = os.path.realpath(path)
|
||||
logger.debug("The abs path is %r", path)
|
||||
|
||||
# check the path is exist and write permissions?
|
||||
if os.path.exists(path):
|
||||
real_path = path
|
||||
else:
|
||||
# All exceptions need to be caught because create directory maybe have some limit(permissions)
|
||||
logger.debug("The directory(%s) doesn't exist, will create it", path)
|
||||
try:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
real_path = path
|
||||
except PermissionError as e:
|
||||
logger.error("No write permission on the directory(%r), error = %r", path, e)
|
||||
raise TypeError("No write permission on the directory.")
|
||||
return real_path
|
||||
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
If the loss in NAN or INF terminating training.
|
||||
Note:
|
||||
if per_print_times is 0 do not print loss.
|
||||
Args:
|
||||
per_print_times (int): Print loss every times. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self, per_print_times=1):
|
||||
super(LossCallBack, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0")
|
||||
self._per_print_times = per_print_times
|
||||
self.step_start_time = time.time()
|
||||
|
||||
def step_begin(self, run_context):
|
||||
self.step_start_time = time.time()
|
||||
|
||||
def step_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
step_time_span = time.time() - self.step_start_time
|
||||
total_time_span = step_time_span
|
||||
cur_step_num = cb_params.cur_step_num
|
||||
if cur_step_num % cfg.Thor.frequency == 0:
|
||||
step_time_span = step_time_span / (cfg.Thor.frequency - 1)
|
||||
print("epoch: {}, step: {}, outputs are {}, total_time_span is {}, step_time_span is {}".format(
|
||||
cb_params.cur_epoch_num, cb_params.cur_step_num,
|
||||
str(cb_params.net_outputs), total_time_span, step_time_span))
|
||||
|
||||
|
||||
def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix):
|
||||
"""
|
||||
Find the ckpt finetune generated and load it into eval network.
|
||||
"""
|
||||
files = os.listdir(load_finetune_checkpoint_dir)
|
||||
pre_len = len(prefix)
|
||||
max_num = 0
|
||||
for filename in files:
|
||||
name_ext = os.path.splitext(filename)
|
||||
if name_ext[-1] != ".ckpt":
|
||||
continue
|
||||
# steps_per_epoch = ds.get_dataset_size()
|
||||
if filename.find(prefix) == 0 and not filename[pre_len].isalpha():
|
||||
index = filename[pre_len:].find("-")
|
||||
if index == 0 and max_num == 0:
|
||||
load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
|
||||
elif index not in (0, -1):
|
||||
name_split = name_ext[-2].split('_')
|
||||
if (steps_per_epoch != int(name_split[len(name_split) - 1])) \
|
||||
or (epoch_num != int(filename[pre_len + index + 1:pre_len + index + 2])):
|
||||
continue
|
||||
num = filename[pre_len + 1:pre_len + index]
|
||||
if int(num) > max_num:
|
||||
max_num = int(num)
|
||||
load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename)
|
||||
return load_finetune_checkpoint_path
|
||||
|
||||
|
||||
class BertLearningRate(LearningRateSchedule):
|
||||
"""
|
||||
Warmup-decay learning rate for Bert network.
|
||||
"""
|
||||
|
||||
def __init__(self, learning_rate, end_learning_rate, warmup_steps, decay_steps, power):
|
||||
super(BertLearningRate, self).__init__()
|
||||
self.warmup_lr = WarmUpLR(learning_rate, warmup_steps)
|
||||
self.decay_lr = PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
|
||||
self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
|
||||
|
||||
self.greater = P.Greater()
|
||||
self.one = Tensor(np.array([1.0]).astype(np.float32))
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, global_step):
|
||||
is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
|
||||
warmup_lr = self.warmup_lr(global_step)
|
||||
decay_lr = self.decay_lr(global_step)
|
||||
lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
|
||||
return lr
|
Loading…
Reference in New Issue