modify the codes for bert_thor

2nd for update thor bert
This commit is contained in:
zongha 2020-08-14 17:05:42 +08:00
parent 05b03fe017
commit da142ccfd3
13 changed files with 138 additions and 275 deletions

View File

@ -11,14 +11,14 @@ This is an example of training bert by second-order optimizer THOR. THOR is a no
## Running the Example
### Pre-Training
- Set options in `config.py`, including lossscale, optimizer and network. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about dataset and the json schema file.
- Set options in `config.py`, including optimizer and network. Click [here](https://www.mindspore.cn/tutorial/zh-CN/master/use/data_preparation/loading_the_datasets.html#tfrecord) for more information about dataset and the json schema file.
- Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base and BERT-NEZHA model.
- Run `run_standalone_pretrain.sh` for non-distributed pre-training of BERT-base, BERT-NEZHA and BERT-large model.
``` bash
sh scripts/run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR
```
- Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base and BERT-NEZHA model.
- Run `run_distribute_pretrain.sh` for distributed pre-training of BERT-base, BERT-NEZHA and BERT-large model.
``` bash
sh scripts/run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR RANK_TABLE_FILE
@ -30,7 +30,7 @@ This is an example of training bert by second-order optimizer THOR. THOR is a no
usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_num N] [--device_id N]
[--enable_save_ckpt ENABLE_SAVE_CKPT]
[--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--checkpoint_path CHECKPOINT_PATH]
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--save_checkpoint_path CHECKPOINT_PATH]
[--save_checkpoint_steps N] [--save_checkpoint_num N]
[--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR]
@ -44,7 +44,7 @@ options:
--do_shuffle enable shuffle: "true" | "false", default is "true"
--enable_data_sink enable data sink: "true" | "false", default is "true"
--data_sink_steps set data sink steps: N, default is 1
--checkpoint_path path to save checkpoint files: PATH, default is ""
--save_checkpoint_path path to save checkpoint files: PATH, default is ""
--save_checkpoint_steps steps for saving checkpoint files: N, default is 1000
--save_checkpoint_num number for saving checkpoint files: N, default is 1
--data_dir path to dataset directory: PATH, default is ""
@ -55,7 +55,7 @@ It contains of parameters of BERT model and options for training, which is set i
### Options:
```
config.py:
bert_network version of BERT model: base | nezha, default is base
bert_network version of BERT model: base | nezha | large, default is large
optimizer optimizer used in the network: AdamWerigtDecayDynamicLR | Lamb | Momentum | Thor, default is "Thor"
```
@ -63,7 +63,7 @@ config.py:
### Parameters:
```
Parameters for dataset and network (Pre-Training/Evaluation):
batch_size batch size of input dataset: N, default is 8
batch_size batch size of input dataset: N, default is 12
seq_length length of input sequence: N, default is 128
vocab_size size of each embedding vector: N, must be consistant with the dataset you use. Default is 21136
hidden_size size of bert encoder layers: N, default is 768
@ -87,7 +87,7 @@ Parameters for optimizer:
momentum momentum for the moving average: Q
weight_decay weight decay: Q
loss_scale loss scale: N
frequency the step interval to update second-order information matrix: N, default is 10
batch_size batch size of input dataset: N, default is 8
frequency the step interval to update second-order information matrix: N, default is 100
batch_size batch size of input dataset: N, default is 12
```

View File

@ -19,7 +19,6 @@ python run_pretrain.py
import argparse
import os
import numpy
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from src.bert_net_config import bert_net_cfg
@ -27,10 +26,8 @@ from src.config import cfg
from src.dataset import create_bert_dataset
from src.lr_generator import get_bert_lr, get_bert_damping
from src.model_thor import Model
# from src.thor_for_bert import THOR
from src.thor_for_bert_arg import THOR
from src.utils import LossCallBack, BertLearningRate
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
from mindspore import context
@ -69,8 +66,8 @@ def run_pretrain():
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id,
save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
device_id=args_opt.device_id, save_graphs=False)
context.set_context(reserve_class_name_in_scope=False)
context.set_context(variable_memory_max_size="30GB")
ckpt_save_dir = args_opt.save_checkpoint_path
@ -165,15 +162,13 @@ def run_pretrain():
optimizer = THOR(filter(lambda x: x.requires_grad, net_with_loss.get_parameters()), lr, cfg.Thor.momentum,
filter(lambda x: 'matrix_A' in x.name, net_with_loss.get_parameters()),
filter(lambda x: 'matrix_G' in x.name, net_with_loss.get_parameters()),
filter(lambda x: 'A_inv_max' in x.name, net_with_loss.get_parameters()),
filter(lambda x: 'G_inv_max' in x.name, net_with_loss.get_parameters()),
cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
bert_net_cfg.batch_size, damping)
else:
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
format(cfg.optimizer))
callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
if args_opt.enable_save_ckpt == "true":
if args_opt.enable_save_ckpt == "true" and rank == 0:
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
keep_checkpoint_max=args_opt.save_checkpoint_num)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck)

View File

@ -37,25 +37,26 @@ do
rm -rf LOG$i
mkdir ./LOG$i
cp *.py ./LOG$i
cp -r src ./LOG$i
cp ../*.py ./LOG$i
cp -r ../src ./LOG$i
cd ./LOG$i || exit
echo "start training for rank $i, device $DEVICE_ID"
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python ../run_pretrain.py \
python run_pretrain.py \
--distribute="true" \
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--device_num=$RANK_SIZE \
--enable_save_ckpt="true" \
--enable_lossscale="false" \
--do_shuffle="true" \
--do_shuffle="false" \
--enable_data_sink="true" \
--data_sink_steps=1000 \
--load_checkpoint_path="" \
--save_checkpoint_steps=5000 \
--save_checkpoint_path='./' \
--save_checkpoint_steps=1000 \
--save_checkpoint_num=30 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
cd ../
done
done

View File

@ -20,27 +20,39 @@ echo "bash run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR"
echo "for example: bash run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json"
echo "=============================================================================================================="
DEVICE_ID=$1
EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_pretrain.py \
--distribute="false" \
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--enable_save_ckpt="true" \
--enable_lossscale="true" \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=1 \
--load_checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
ulimit -u unlimited
export DEVICE_ID=$1
export RANK_SIZE=1
if [ -d "LOG" ];
then
rm -rf ./LOG
fi
mkdir ./LOG
cp ../*.py ./LOG
cp -r ../src ./LOG
cd ./LOG || exit
echo "start training for device $DEVICE_ID"
env > env.log
python run_pretrain.py \
--distribute="false" \
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--device_num=$RANK_SIZE \
--enable_save_ckpt="true" \
--enable_lossscale="false" \
--do_shuffle="false" \
--enable_data_sink="true" \
--data_sink_steps=1000 \
--load_checkpoint_path="" \
--save_checkpoint_path='./' \
--save_checkpoint_steps=5000 \
--save_checkpoint_num=20 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
cd ../

View File

@ -35,6 +35,8 @@ from .thor_layer import Dense_Thor
damping = get_bert_damping()
loss_scale = cfg.Thor.loss_scale
frequency = cfg.Thor.frequency
batch_size = cfg.Thor.batch_size
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 1.0
@ -91,9 +93,9 @@ class GetMaskedLMOutput(nn.Cell):
bias_init='zeros',
damping=damping,
loss_scale=loss_scale,
frequency=1,
frequency=frequency,
activation=config.hidden_act,
batch_size=config.batch_size).to_float(config.compute_type)
batch_size=batch_size).to_float(config.compute_type)
self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)
self.output_bias = Parameter(
initializer(

View File

@ -34,6 +34,7 @@ from .thor_layer import Dense_Thor, Embedding_Thor
damping = get_bert_damping()
loss_scale = cfg.Thor.loss_scale
frequency = cfg.Thor.frequency
batch_size = cfg.Thor.batch_size
@ -200,11 +201,10 @@ class EmbeddingPostprocessor(nn.Cell):
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
name='embedding_table',
is_expand=False,
batch_size=batch_size,
damping=damping,
loss_scale=loss_scale,
frequency=1)
frequency=frequency)
self.shape_flat = (-1,)
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
@ -225,11 +225,10 @@ class EmbeddingPostprocessor(nn.Cell):
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
name='full_position_embeddings',
is_expand=False,
batch_size=batch_size,
damping=damping,
loss_scale=loss_scale,
frequency=1)
frequency=frequency)
self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
self.layernorm = nn.LayerNorm((embedding_size,))
@ -274,7 +273,7 @@ class BertOutput(nn.Cell):
bias_init='zeros',
damping=damping,
loss_scale=loss_scale,
frequency=1,
frequency=frequency,
activation=None,
batch_size=batch_size).to_float(compute_type)
self.dropout = nn.Dropout(1 - dropout_prob)
@ -488,7 +487,7 @@ class BertAttention(nn.Cell):
bias_init='zeros',
damping=damping,
loss_scale=loss_scale,
frequency=1,
frequency=frequency,
activation=query_act,
batch_size=batch_size).to_float(compute_type)
self.key_layer = Dense_Thor(in_channels=to_tensor_width,
@ -498,7 +497,7 @@ class BertAttention(nn.Cell):
bias_init='zeros',
damping=damping,
loss_scale=loss_scale,
frequency=1,
frequency=frequency,
activation=key_act,
batch_size=batch_size).to_float(compute_type)
self.value_layer = Dense_Thor(in_channels=to_tensor_width,
@ -508,7 +507,7 @@ class BertAttention(nn.Cell):
bias_init='zeros',
damping=damping,
loss_scale=loss_scale,
frequency=1,
frequency=frequency,
activation=value_act,
batch_size=batch_size).to_float(compute_type)
self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
@ -764,7 +763,7 @@ class BertEncoderCell(nn.Cell):
bias_init='zeros',
damping=damping,
loss_scale=loss_scale,
frequency=1,
frequency=frequency,
activation=hidden_act,
batch_size=batch_size).to_float(compute_type)
self.output = BertOutput(in_channels=intermediate_size,
@ -945,11 +944,10 @@ class BertModel(nn.Cell):
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=config.initializer_range,
name='embedding_table',
is_expand=True,
batch_size=batch_size,
damping=damping,
loss_scale=loss_scale,
frequency=1)
frequency=frequency)
self.bert_embedding_postprocessor = EmbeddingPostprocessor(
embedding_size=self.embedding_size,
embedding_shape=output_embedding_shape,
@ -991,7 +989,7 @@ class BertModel(nn.Cell):
bias_init='zeros',
damping=damping,
loss_scale=loss_scale,
frequency=1,
frequency=frequency,
activation="tanh",
batch_size=batch_size).to_float(config.compute_type)
self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)

View File

@ -19,9 +19,6 @@ from easydict import EasyDict as edict
cfg = edict({
'bert_network': 'large',
'loss_scale_value': 65536,
'scale_factor': 2,
'scale_window': 1000,
'optimizer': 'Thor',
'AdamWeightDecay': edict({
'learning_rate': 3e-5,
@ -49,7 +46,7 @@ cfg = edict({
'momentum': 0.9,
'weight_decay': 5e-4,
'loss_scale': 1,
'frequency': 10,
'batch_size': 8,
'frequency': 100,
'batch_size': 12,
}),
})

View File

@ -16,7 +16,6 @@
Data operations, will be used in run_pretrain.py
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
@ -37,7 +36,7 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
shuffle=de.Shuffle.FILES if do_shuffle == "true" else False,
num_shards=device_num, shard_id=rank, shard_equal_rows=True)
num_shards=device_num, shard_id=rank, shard_equal_rows=False)
ori_dataset_size = ds.get_dataset_size()
print('origin dataset size: ', ori_dataset_size)
type_cast_op = C.TypeCast(mstype.int32)

View File

@ -80,7 +80,7 @@ def _tensors_cast_datatype(datatype, grad):
return F.cast(grad, datatype)
class DistributedGradReducerThor1(Cell):
class DistributedGradReducerThor(Cell):
"""
A distributed optimizer.
@ -154,7 +154,7 @@ class DistributedGradReducerThor1(Cell):
"""
def __init__(self, parameters, group, mean=True, degree=None):
super(DistributedGradReducerThor1, self).__init__(auto_prefix=False)
super(DistributedGradReducerThor, self).__init__(auto_prefix=False)
self.hyper_map = C.HyperMap()
self.mul = P.Mul()
if degree is None:
@ -168,7 +168,7 @@ class DistributedGradReducerThor1(Cell):
_init_optimizer_allreduce(group)
def construct(self, grads):
"""construct of DistributedGradReducerThor1"""
"""construct of DistributedGradReducerThor"""
# In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
# result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
# and cast back after the operation.

View File

@ -58,7 +58,7 @@ def get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps,
# bert kfac hyperparam setting
def get_bert_lr():
learning_rate = Tensor(
get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=4e-4, warmup_steps=0, total_steps=30000,
get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=3.1e-3, warmup_steps=0, total_steps=30000,
poly_power=1))
return learning_rate

View File

@ -46,9 +46,8 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
class THOR(Optimizer):
"""THOR"""
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0,
loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03, frequency=10,
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, weight_decay=0.0,
loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03,
decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()):
super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale)
if isinstance(momentum, float) and momentum < 0.0:
@ -60,8 +59,6 @@ class THOR(Optimizer):
self.opt = P.ApplyMomentum()
self.matrix_A = ParameterTuple(matrix_A)
self.matrix_G = ParameterTuple(matrix_G)
self.A_inv_max = ParameterTuple(A_inv_max)
self.G_inv_max = ParameterTuple(G_inv_max)
self.matmul = P.MatMul()
self.transpose = P.Transpose()
self.shape = P.Shape()
@ -70,16 +67,8 @@ class THOR(Optimizer):
self.gather = P.GatherV2()
self.matrix_A_inv = ()
self.matrix_G_inv = ()
self.matrix_max_inv = ()
self.num_hidden_layers = num_hidden_layers
fc_layer_num = num_hidden_layers * 6 + 5
for i in range(fc_layer_num):
self.matrix_max_inv = self.matrix_max_inv + (
Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
self.log = P.Log()
self.exp = P.Exp()
self.sqrt = P.Sqrt()
self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
self.assign = P.Assign()
self.cast = P.Cast()
self.thor = True
@ -90,7 +79,6 @@ class THOR(Optimizer):
self.inv = P.Inv()
self.batch_size = batch_size
self.damping = damping
self.freq = Tensor(frequency, mstype.int32)
self.one = Tensor(1, mstype.int32)
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
@ -106,26 +94,20 @@ class THOR(Optimizer):
g = gradients[em_idx]
matrix_idx = em_idx
temp_a_ori = self.matrix_A[matrix_idx]
temp_a = self.expand(temp_a_ori, 1)
temp_g = self.matrix_G[matrix_idx]
G_max = self.G_inv_max[matrix_idx]
temp_g = self.cast(temp_g, mstype.float32)
matrix_G_inv_max = self.log(G_max)
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
matrix_G_inv_max = self.exp(matrix_G_inv_max)
temp_g = self.mul(temp_g, matrix_G_inv_max)
g = self.mul(temp_a, g)
g = self.cast(g, mstype.float16)
temp_a_ori = F.depend(temp_a_ori, g)
temp_g = F.depend(temp_g, g)
temp_a = self.expand(temp_a_ori, 1)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.mul(temp_a, g)
g = self.matmul(g, temp_g)
g = self.cast(g, mstype.float32)
g = self.mul(g, G_max)
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a_ori)
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
fake_max = self.assign(self.matrix_max_inv[matrix_idx], G_max)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
g = F.depend(g, fake_max)
new_grads = new_grads + (g,)
# process bert_embedding_postprocessor.layernorm
grad_idx = 3
@ -180,32 +162,18 @@ class THOR(Optimizer):
matrix_idx = 6 * i + offset_idx + 3
temp_a = self.matrix_A[matrix_idx]
temp_g = self.matrix_G[matrix_idx]
temp_a = self.cast(temp_a, mstype.float32)
temp_g = self.cast(temp_g, mstype.float32)
matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx])
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
matrix_A_inv_max = self.exp(matrix_A_inv_max)
temp_a = self.mul(temp_a, matrix_A_inv_max)
matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx])
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
matrix_G_inv_max = self.exp(matrix_G_inv_max)
temp_g = self.mul(temp_g, matrix_G_inv_max)
temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx])
temp_a = F.depend(temp_a, g)
temp_g = F.depend(temp_g, g)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, temp_max)
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
g = F.depend(g, fake_max)
new_grads = new_grads + (g,)
new_grads = new_grads + (gradients[grad_idx + 1],)
@ -216,32 +184,18 @@ class THOR(Optimizer):
pooler_bias = gradients[pooler_layer_idx + 1]
temp_a = self.matrix_A[matrix_idx]
temp_g = self.matrix_G[matrix_idx]
temp_a = self.cast(temp_a, mstype.float32)
temp_g = self.cast(temp_g, mstype.float32)
matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx])
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
matrix_A_inv_max = self.exp(matrix_A_inv_max)
temp_a = self.mul(temp_a, matrix_A_inv_max)
matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx])
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
matrix_G_inv_max = self.exp(matrix_G_inv_max)
temp_g = self.mul(temp_g, matrix_G_inv_max)
temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx])
temp_a = F.depend(temp_a, g)
temp_g = F.depend(temp_g, g)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, temp_max)
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
g = F.depend(g, fake_max)
new_grads = new_grads + (g, pooler_bias)
# for cls1 fc layer: mlm
@ -251,38 +205,26 @@ class THOR(Optimizer):
mlm_bias = gradients[mlm_fc_idx + 1]
temp_a = self.matrix_A[matrix_idx]
temp_g = self.matrix_G[matrix_idx]
temp_a = self.cast(temp_a, mstype.float32)
temp_g = self.cast(temp_g, mstype.float32)
matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx])
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
matrix_A_inv_max = self.exp(matrix_A_inv_max)
temp_a = self.mul(temp_a, matrix_A_inv_max)
matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx])
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
matrix_G_inv_max = self.exp(matrix_G_inv_max)
temp_g = self.mul(temp_g, matrix_G_inv_max)
temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx])
temp_a = F.depend(temp_a, g)
temp_g = F.depend(temp_g, g)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, temp_max)
# add bert.cls1.output_bias grad
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
g = F.depend(g, fake_max)
new_grads = new_grads + (gradients[mlm_fc_idx - 1],)
new_grads = new_grads + (g, mlm_bias)
# add bert.cls1.layernorm grad
begin_idx = mlm_fc_idx + 2
end_idx = mlm_fc_idx + 4
new_grads = new_grads + gradients[begin_idx: end_idx]
lenth = len(gradients)
new_grads = new_grads + gradients[lenth - 2: lenth]
gradients = new_grads
@ -293,15 +235,16 @@ class THOR(Optimizer):
g = gradients[em_idx]
matrix_idx = em_idx
temp_a = self.matrix_A[matrix_idx]
temp_a = self.expand(temp_a, 1)
temp_g = self.matrix_G[matrix_idx]
matrix_max = self.matrix_max_inv[matrix_idx]
g = self.mul(temp_a, g)
temp_a = F.depend(temp_a, g)
temp_g = F.depend(temp_g, g)
temp_a = self.expand(temp_a, 1)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.mul(temp_a, g)
g = self.matmul(g, temp_g)
g = self.cast(g, mstype.float32)
g = self.mul(g, matrix_max)
new_grads = new_grads + (g,)
# process bert_embedding_postprocessor.layernorm
grad_idx = 3
@ -356,15 +299,14 @@ class THOR(Optimizer):
matrix_idx = 6 * i + offset_idx + 3
temp_a = self.matrix_A[matrix_idx]
temp_g = self.matrix_G[matrix_idx]
matrix_max = self.matrix_max_inv[matrix_idx]
temp_a = F.depend(temp_a, g)
temp_g = F.depend(temp_g, g)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, matrix_max)
new_grads = new_grads + (g,)
new_grads = new_grads + (gradients[grad_idx + 1],)
@ -375,15 +317,14 @@ class THOR(Optimizer):
pooler_bias = gradients[pooler_layer_idx + 1]
temp_a = self.matrix_A[matrix_idx]
temp_g = self.matrix_G[matrix_idx]
matrix_max = self.matrix_max_inv[matrix_idx]
temp_a = F.depend(temp_a, g)
temp_g = F.depend(temp_g, g)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, matrix_max)
new_grads = new_grads + (g, pooler_bias)
# for cls1 fc layer: mlm
@ -393,15 +334,14 @@ class THOR(Optimizer):
mlm_bias = gradients[mlm_fc_idx + 1]
temp_a = self.matrix_A[matrix_idx]
temp_g = self.matrix_G[matrix_idx]
matrix_max = self.matrix_max_inv[matrix_idx]
temp_a = F.depend(temp_a, g)
temp_g = F.depend(temp_g, g)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, matrix_max)
# add bert.cls1.output_bias grad
new_grads = new_grads + (gradients[mlm_fc_idx - 1],)
new_grads = new_grads + (g, mlm_bias)
@ -409,6 +349,7 @@ class THOR(Optimizer):
begin_idx = mlm_fc_idx + 2
end_idx = mlm_fc_idx + 4
new_grads = new_grads + gradients[begin_idx: end_idx]
lenth = len(gradients)
new_grads = new_grads + gradients[lenth - 2: lenth]
gradients = new_grads

View File

@ -21,7 +21,7 @@ from mindspore.common.tensor import Tensor
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
from .grad_reducer_thor1 import DistributedGradReducerThor1
from .grad_reducer_thor import DistributedGradReducerThor
momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@ -48,9 +48,8 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
class THOR(Optimizer):
"""THOR"""
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0,
loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03, frequency=10,
def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, weight_decay=0.0,
loss_scale=1.0, num_hidden_layers=24, batch_size=12, damping=0.03,
decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()):
super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale)
if isinstance(momentum, float) and momentum < 0.0:
@ -62,8 +61,6 @@ class THOR(Optimizer):
self.opt = P.ApplyMomentum()
self.matrix_A = ParameterTuple(matrix_A)
self.matrix_G = ParameterTuple(matrix_G)
self.A_inv_max = ParameterTuple(A_inv_max)
self.G_inv_max = ParameterTuple(G_inv_max)
self.matmul = P.MatMul()
self.transpose = P.Transpose()
self.shape = P.Shape()
@ -72,16 +69,8 @@ class THOR(Optimizer):
self.gather = P.GatherV2()
self.matrix_A_inv = ()
self.matrix_G_inv = ()
self.matrix_max_inv = ()
self.num_hidden_layers = num_hidden_layers
fc_layer_num = num_hidden_layers * 6 + 5
for i in range(fc_layer_num):
self.matrix_max_inv = self.matrix_max_inv + (
Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
self.log = P.Log()
self.exp = P.Exp()
self.sqrt = P.Sqrt()
self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
self.assign = P.Assign()
self.cast = P.Cast()
self.thor = True
@ -92,12 +81,11 @@ class THOR(Optimizer):
self.inv = P.Inv()
self.batch_size = batch_size
self.damping = damping
self.freq = Tensor(frequency, mstype.int32)
self.one = Tensor(1, mstype.int32)
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
mean = _get_mirror_mean()
degree = _get_device_num()
self.grad_reducer_g = DistributedGradReducerThor1(self.parameters, 3, mean, degree)
self.grad_reducer_g = DistributedGradReducerThor(self.parameters, 3, mean, degree)
def construct(self, gradients):
"""construct of THOR"""
@ -111,26 +99,20 @@ class THOR(Optimizer):
g = gradients[em_idx]
matrix_idx = em_idx
temp_a_ori = self.matrix_A[matrix_idx]
temp_a = self.expand(temp_a_ori, 1)
temp_g = self.matrix_G[matrix_idx]
G_max = self.G_inv_max[matrix_idx]
temp_g = self.cast(temp_g, mstype.float32)
matrix_G_inv_max = self.log(G_max)
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
matrix_G_inv_max = self.exp(matrix_G_inv_max)
temp_g = self.mul(temp_g, matrix_G_inv_max)
g = self.mul(temp_a, g)
g = self.cast(g, mstype.float16)
temp_a_ori = F.depend(temp_a_ori, g)
temp_g = F.depend(temp_g, g)
temp_a = self.expand(temp_a_ori, 1)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.mul(temp_a, g)
g = self.matmul(g, temp_g)
g = self.cast(g, mstype.float32)
g = self.mul(g, G_max)
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a_ori)
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
fake_max = self.assign(self.matrix_max_inv[matrix_idx], G_max)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
g = F.depend(g, fake_max)
new_grads = new_grads + (g,)
# process bert_embedding_postprocessor.layernorm
grad_idx = 3
@ -185,32 +167,18 @@ class THOR(Optimizer):
matrix_idx = 6 * i + offset_idx + 3
temp_a = self.matrix_A[matrix_idx]
temp_g = self.matrix_G[matrix_idx]
temp_a = self.cast(temp_a, mstype.float32)
temp_g = self.cast(temp_g, mstype.float32)
matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx])
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
matrix_A_inv_max = self.exp(matrix_A_inv_max)
temp_a = self.mul(temp_a, matrix_A_inv_max)
matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx])
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
matrix_G_inv_max = self.exp(matrix_G_inv_max)
temp_g = self.mul(temp_g, matrix_G_inv_max)
temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx])
temp_a = F.depend(temp_a, g)
temp_g = F.depend(temp_g, g)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, temp_max)
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
g = F.depend(g, fake_max)
new_grads = new_grads + (g,)
new_grads = new_grads + (gradients[grad_idx + 1],)
@ -221,32 +189,18 @@ class THOR(Optimizer):
pooler_bias = gradients[pooler_layer_idx + 1]
temp_a = self.matrix_A[matrix_idx]
temp_g = self.matrix_G[matrix_idx]
temp_a = self.cast(temp_a, mstype.float32)
temp_g = self.cast(temp_g, mstype.float32)
matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx])
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
matrix_A_inv_max = self.exp(matrix_A_inv_max)
temp_a = self.mul(temp_a, matrix_A_inv_max)
matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx])
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
matrix_G_inv_max = self.exp(matrix_G_inv_max)
temp_g = self.mul(temp_g, matrix_G_inv_max)
temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx])
temp_a = F.depend(temp_a, g)
temp_g = F.depend(temp_g, g)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, temp_max)
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
g = F.depend(g, fake_max)
new_grads = new_grads + (g, pooler_bias)
# for cls1 fc layer: mlm
@ -256,38 +210,26 @@ class THOR(Optimizer):
mlm_bias = gradients[mlm_fc_idx + 1]
temp_a = self.matrix_A[matrix_idx]
temp_g = self.matrix_G[matrix_idx]
temp_a = self.cast(temp_a, mstype.float32)
temp_g = self.cast(temp_g, mstype.float32)
matrix_A_inv_max = self.log(self.A_inv_max[matrix_idx])
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
matrix_A_inv_max = self.exp(matrix_A_inv_max)
temp_a = self.mul(temp_a, matrix_A_inv_max)
matrix_G_inv_max = self.log(self.G_inv_max[matrix_idx])
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
matrix_G_inv_max = self.exp(matrix_G_inv_max)
temp_g = self.mul(temp_g, matrix_G_inv_max)
temp_max = self.mul(self.A_inv_max[matrix_idx], self.G_inv_max[matrix_idx])
temp_a = F.depend(temp_a, g)
temp_g = F.depend(temp_g, g)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, temp_max)
# add bert.cls1.output_bias grad
fake_A = self.assign(self.matrix_A[matrix_idx], temp_a)
fake_G = self.assign(self.matrix_G[matrix_idx], temp_g)
fake_max = self.assign(self.matrix_max_inv[matrix_idx], temp_max)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
g = F.depend(g, fake_max)
new_grads = new_grads + (gradients[mlm_fc_idx - 1],)
new_grads = new_grads + (g, mlm_bias)
# add bert.cls1.layernorm grad
begin_idx = mlm_fc_idx + 2
end_idx = mlm_fc_idx + 4
new_grads = new_grads + gradients[begin_idx: end_idx]
lenth = len(gradients)
new_grads = new_grads + gradients[lenth - 2: lenth]
gradients = new_grads
@ -299,15 +241,16 @@ class THOR(Optimizer):
g = gradients[em_idx]
matrix_idx = em_idx
temp_a = self.matrix_A[matrix_idx]
temp_a = self.expand(temp_a, 1)
temp_g = self.matrix_G[matrix_idx]
matrix_max = self.matrix_max_inv[matrix_idx]
g = self.mul(temp_a, g)
temp_a = F.depend(temp_a, g)
temp_g = F.depend(temp_g, g)
temp_a = self.expand(temp_a, 1)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.mul(temp_a, g)
g = self.matmul(g, temp_g)
g = self.cast(g, mstype.float32)
g = self.mul(g, matrix_max)
new_grads = new_grads + (g,)
# process bert_embedding_postprocessor.layernorm
grad_idx = 3
@ -362,15 +305,14 @@ class THOR(Optimizer):
matrix_idx = 6 * i + offset_idx + 3
temp_a = self.matrix_A[matrix_idx]
temp_g = self.matrix_G[matrix_idx]
matrix_max = self.matrix_max_inv[matrix_idx]
temp_a = F.depend(temp_a, g)
temp_g = F.depend(temp_g, g)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, matrix_max)
new_grads = new_grads + (g,)
new_grads = new_grads + (gradients[grad_idx + 1],)
@ -381,15 +323,14 @@ class THOR(Optimizer):
pooler_bias = gradients[pooler_layer_idx + 1]
temp_a = self.matrix_A[matrix_idx]
temp_g = self.matrix_G[matrix_idx]
matrix_max = self.matrix_max_inv[matrix_idx]
temp_a = F.depend(temp_a, g)
temp_g = F.depend(temp_g, g)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, matrix_max)
new_grads = new_grads + (g, pooler_bias)
# for cls1 fc layer: mlm
@ -399,15 +340,14 @@ class THOR(Optimizer):
mlm_bias = gradients[mlm_fc_idx + 1]
temp_a = self.matrix_A[matrix_idx]
temp_g = self.matrix_G[matrix_idx]
matrix_max = self.matrix_max_inv[matrix_idx]
temp_a = F.depend(temp_a, g)
temp_g = F.depend(temp_g, g)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, matrix_max)
# add bert.cls1.output_bias grad
new_grads = new_grads + (gradients[mlm_fc_idx - 1],)
new_grads = new_grads + (g, mlm_bias)
@ -415,6 +355,7 @@ class THOR(Optimizer):
begin_idx = mlm_fc_idx + 2
end_idx = mlm_fc_idx + 4
new_grads = new_grads + gradients[begin_idx: end_idx]
lenth = len(gradients)
new_grads = new_grads + gradients[lenth - 2: lenth]
gradients = new_grads

View File

@ -14,7 +14,6 @@
# ============================================================================
"""thor_layer"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore._checkparam import check_bool, check_int_positive
from mindspore.common.initializer import TruncatedNormal, initializer
@ -24,7 +23,6 @@ from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
from mindspore.ops import operations as P
class Embedding_Thor(Cell):
"""
A embeddings lookup table with a fixed dictionary and size.
@ -37,7 +35,6 @@ class Embedding_Thor(Cell):
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
"""
def __init__(self,
vocab_size,
embedding_size,
@ -45,11 +42,10 @@ class Embedding_Thor(Cell):
use_one_hot_embeddings=False,
initializer_range=0.02,
name='embedding_table',
is_expand=False,
batch_size=12,
damping=0.03,
loss_scale=1,
frequency=10,
frequency=100,
):
super(Embedding_Thor, self).__init__()
self.vocab_size = vocab_size
@ -59,7 +55,6 @@ class Embedding_Thor(Cell):
[vocab_size, embedding_size]),
name=name)
self.thor = True
self.is_expand = is_expand
self.expand = P.ExpandDims()
self.shape_flat = (-1,)
self.gather = P.GatherV2()
@ -71,13 +66,11 @@ class Embedding_Thor(Cell):
self.em_shape = tuple(embedding_shape)
self.shape = P.Shape()
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
self.matrix_A_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)), name='matrix_A_inv',
requires_grad=False)
self.matrix_A_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float16)),
name='matrix_A_inv', requires_grad=False)
self.matrix_G_inv = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float16)),
name="matrix_G_inv", requires_grad=False)
self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False)
self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False)
self.fused_abs_max = P.CusFusedAbsMax1()
self.fake_G = Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float16))
self.dampingA = Tensor(np.ones([vocab_size]).astype(np.float32))
self.dampingG = Tensor(np.identity(embedding_size), mstype.float32)
@ -117,9 +110,6 @@ class Embedding_Thor(Cell):
matrix_G = matrix_G + damping * dampingG
matrix_G_inv = self.cholesky(matrix_G)
matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)
matrix_G_inv_max = self.fused_abs_max(matrix_G_inv)
matrix_G_inv_max = self.fused_abs_max(matrix_G_inv_max)
self.G_inv_max = matrix_G_inv_max
matrix_G_inv = self.matrix_combine(matrix_G_inv)
matrix_G_inv = self.cast(matrix_G_inv, mstype.float16)
self.matrix_G_inv = matrix_G_inv
@ -127,8 +117,6 @@ class Embedding_Thor(Cell):
def construct(self, input_ids):
"""construct of Embedding_Thor"""
if self.is_expand:
input_ids = self.expand(input_ids, -1)
flat_ids = self.reshape(input_ids, self.shape_flat)
if self.use_one_hot_embeddings:
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
@ -146,6 +134,7 @@ class Embedding_Thor(Cell):
dampingA = self.cast(self.dampingA, mstype.float32)
matrix_A = matrix_A + damping * dampingA
matrix_A_inv = self.inv(matrix_A)
matrix_A_inv = self.cast(matrix_A_inv, mstype.float16)
self.matrix_A_inv = matrix_A_inv
self.matrix_G_inv = self.fake_G
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
@ -156,11 +145,9 @@ class Embedding_Thor(Cell):
output = self.reshape(output_for_reshape, self.em_shape)
return output, self.embedding_table
class Dense_Thor(Cell):
"""Dense_Thor"""
# @cell_attr_register(attrs=['has_bias', 'activation', 'in_channels', 'out_channels'])
def __init__(self,
in_channels,
out_channels,
@ -168,7 +155,7 @@ class Dense_Thor(Cell):
bias_init='zeros',
damping=0.03,
loss_scale=1,
frequency=10,
frequency=100,
has_bias=False,
activation=None,
batch_size=12):
@ -200,9 +187,6 @@ class Dense_Thor(Cell):
name='matrix_A_inv', requires_grad=False)
self.matrix_G_inv = Parameter(Tensor(np.zeros([out_channels, out_channels]).astype(np.float16)),
name="matrix_G_inv", requires_grad=False)
self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False)
self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False)
self.fused_abs_max = P.CusFusedAbsMax1()
self.fake_G = Tensor(np.zeros([out_channels, out_channels]).astype(np.float16))
self.matmul = P.MatMul(transpose_b=True)
@ -250,9 +234,6 @@ class Dense_Thor(Cell):
matrix_G = matrix_G + damping * dampingG
matrix_G_inv = self.cholesky(matrix_G)
matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)
matrix_G_inv_max = self.fused_abs_max(matrix_G_inv)
matrix_G_inv_max = self.fused_abs_max(matrix_G_inv_max)
self.G_inv_max = matrix_G_inv_max
matrix_G_inv = self.matrix_combine(matrix_G_inv)
matrix_G_inv = self.cast(matrix_G_inv, mstype.float16)
self.matrix_G_inv = matrix_G_inv
@ -265,7 +246,6 @@ class Dense_Thor(Cell):
shape = self.shape(x)
normalizer = self.cast(shape[0], mstype.float32)
matrix_A = self.mul(inputs, 1.0 / normalizer)
damping_step = self.gather(self.damping, self.cov_step, self.axis)
damping_step = self.cast(damping_step, mstype.float32)
damping = self.sqrt(damping_step)
@ -273,9 +253,6 @@ class Dense_Thor(Cell):
matrix_A = matrix_A + damping * dampingA
matrix_A_inv = self.cholesky(matrix_A)
matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv)
matrix_A_inv_max = self.fused_abs_max(matrix_A_inv)
matrix_A_inv_max = self.fused_abs_max(matrix_A_inv_max)
self.A_inv_max = matrix_A_inv_max
matrix_A_inv = self.matrix_combine(matrix_A_inv)
matrix_A_inv = self.cast(matrix_A_inv, mstype.float16)
self.matrix_A_inv = matrix_A_inv