forked from mindspore-Ecosystem/mindspore
!17618 fix bug: resnet50 thor train failed in master
From: @sl_wang Reviewed-by: @guoqi1024,@kisnwang Signed-off-by: @guoqi1024
This commit is contained in:
commit
52fa09667c
|
@ -246,7 +246,7 @@ class ThorGpu(Optimizer):
|
|||
self.matrix_g = ParameterTuple(self.matrix_g)
|
||||
self.weight_decay = weight_decay
|
||||
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
|
||||
self.update_gradient = P.UpdateThorGradient(split_dim=split_dim)
|
||||
self.update_gradient = P.UpdateThorGradient(split_dim=self.split_dim)
|
||||
self.enable_clip_grad = enable_clip_grad
|
||||
self.frequency = frequency
|
||||
self._define_gpu_reducer(split_indices)
|
||||
|
@ -271,9 +271,9 @@ class ThorGpu(Optimizer):
|
|||
self.cast = P.Cast()
|
||||
self.sqrt = P.Sqrt()
|
||||
self.eye = P.Eye()
|
||||
split_dim = 128
|
||||
self.split_dim = 128
|
||||
self.embedding_cholesky = P.CholeskyTrsm()
|
||||
self.cholesky = P.CholeskyTrsm(split_dim=split_dim)
|
||||
self.cholesky = P.CholeskyTrsm(split_dim=self.split_dim)
|
||||
self.vector_matmul = P.BatchMatMul(transpose_a=True)
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.inv = P.Reciprocal()
|
||||
|
@ -312,7 +312,7 @@ class ThorGpu(Optimizer):
|
|||
if layer_type == Conv:
|
||||
matrix_a_dim = in_channels * weight_shape[2] * weight_shape[3]
|
||||
matrix_g_dim = out_channels
|
||||
matrix_a_shape, matrix_g_shape = caculate_matmul_shape(matrix_a_dim, matrix_g_dim, split_dim)
|
||||
matrix_a_shape, matrix_g_shape = caculate_matmul_shape(matrix_a_dim, matrix_g_dim, self.split_dim)
|
||||
matrix_a_inv = Parameter(np.zeros(matrix_a_shape).astype(np.float32),
|
||||
name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
|
||||
matrix_g_inv = Parameter(np.zeros(matrix_g_shape).astype(np.float32),
|
||||
|
|
|
@ -17,7 +17,7 @@ import argparse
|
|||
import numpy as np
|
||||
|
||||
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
|
||||
from src.resnet_thor import resnet50 as resnet
|
||||
from src.resnet import resnet50 as resnet
|
||||
from src.config import config
|
||||
|
||||
parser = argparse.ArgumentParser(description='checkpoint export')
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config."""
|
||||
from src.resnet_thor import resnet50
|
||||
from src.resnet import resnet50
|
||||
|
||||
def create_network(name, *args, **kwargs):
|
||||
if name == 'resnet50_thor':
|
||||
|
|
|
@ -272,7 +272,11 @@ def run_pretrain():
|
|||
accumulation_steps=accumulation_steps,
|
||||
enable_global_norm=enable_global_norm)
|
||||
else:
|
||||
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer, enable_clip_grad=False)
|
||||
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer, enable_clip_grad=True)
|
||||
if cfg.optimizer == "Thor":
|
||||
net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer, sens=cfg.Thor.loss_scale,
|
||||
enable_clip_grad=False)
|
||||
|
||||
|
||||
model = Model(net_with_grads)
|
||||
model = ConvertModelUtils().convert_to_thor_model(model, network=net_with_grads, optimizer=optimizer)
|
||||
|
|
|
@ -269,10 +269,10 @@ class BertTrainOneStepCell(nn.TrainOneStepCell):
|
|||
network (Cell): The training network. Note that loss function should have been added.
|
||||
optimizer (Optimizer): Optimizer for updating the weights.
|
||||
sens (Number): The adjust parameter. Default: 1.0.
|
||||
enable_clip_grad (boolean): If True, clip gradients in BertTrainOneStepCell. Default: False.
|
||||
enable_clip_grad (boolean): If True, clip gradients in BertTrainOneStepCell. Default: True.
|
||||
"""
|
||||
|
||||
def __init__(self, network, optimizer, sens=1.0, enable_clip_grad=False):
|
||||
def __init__(self, network, optimizer, sens=1.0, enable_clip_grad=True):
|
||||
super(BertTrainOneStepCell, self).__init__(network, optimizer, sens)
|
||||
self.cast = P.Cast()
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
|
|
@ -19,7 +19,6 @@ from src.bert_model import BertModel
|
|||
from src.bert_model import BertConfig
|
||||
import mindspore.common.dtype as mstype
|
||||
bert_net_cfg = BertConfig(
|
||||
batch_size=12,
|
||||
seq_length=512,
|
||||
vocab_size=30522,
|
||||
hidden_size=1024,
|
||||
|
@ -33,11 +32,8 @@ bert_net_cfg = BertConfig(
|
|||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
use_relative_positions=False,
|
||||
input_mask_from_dataset=True,
|
||||
token_type_ids_from_dataset=True,
|
||||
dtype=mstype.float32,
|
||||
compute_type=mstype.float16,
|
||||
enable_fused_layernorm=True
|
||||
compute_type=mstype.float16
|
||||
)
|
||||
def create_network(name, *args, **kwargs):
|
||||
'''
|
||||
|
|
Loading…
Reference in New Issue