!22887 clean codecheck for master and fix api error
Merge pull request !22887 from wangshuangling/master
commit e02819cefe
@@ -292,6 +292,9 @@ class Conv2dThor(_ConvThor):
    :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} - (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.

    Note:
        For Ascend, the type of inputs should be subclass of Tensor[Float16], Tensor[Int8].
        For GPU, the type of inputs should be subclass of Tensor[Float32].

    Args:
        in_channels (int): The number of the input channel :math:`C_{in}`.
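As a quick check of the output-width formula above, a minimal plain-Python sketch (the numeric values are illustrative and assume the 'pad'-style formula; the docstring Example below keeps the 640 width, presumably because the default pad_mode preserves the spatial size):

import math

def conv_out_width(w_in, ks_w, padding=0, stride=1, dilation=1):
    # floor(1 + (W_in + 2*padding - ks_w - (ks_w - 1)*(dilation - 1)) / stride)
    return math.floor(1 + (w_in + 2 * padding - ks_w - (ks_w - 1) * (dilation - 1)) / stride)

print(conv_out_width(640, 4))  # 637 for a 4-wide kernel, no padding, stride 1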
@@ -351,7 +354,8 @@ class Conv2dThor(_ConvThor):
    Examples:
        >>> net = nn.Conv2dThor(120, 240, 4, has_bias=False, weight_init='normal')
        >>> x = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
        >>> # for Ascend
        >>> x = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float16)
        >>> print(net(x).shape)
        (1, 240, 1024, 640)
    """
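To mirror the Note in the docstring above, a hedged convenience sketch (not part of the patch) that picks the input dtype from the configured device target:

import numpy as np
import mindspore
from mindspore import Tensor, context

# float16 on Ascend, float32 on GPU, as the Note suggests.
dtype = mindspore.float16 if context.get_context("device_target") == "Ascend" else mindspore.float32
x = Tensor(np.ones([1, 120, 1024, 640]), dtype)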
@@ -638,6 +642,7 @@ class EmbeddingThor(Cell):
            self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
        return s


@constexpr
def _make_axis_range(start, end):
    axis = tuple(range(start, end))
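The hunk ends before the helper's return statement; a plain-Python sketch of what `_make_axis_range` presumably evaluates to (the trailing return is an assumption, not shown in the diff):

def make_axis_range(start, end):
    # e.g. make_axis_range(1, 4) -> (1, 2, 3): the axis indices between start and end
    axis = tuple(range(start, end))
    return axis  # assumed: the original @constexpr helper returns this tuple

print(make_axis_range(1, 4))  # (1, 2, 3)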
@@ -804,10 +809,7 @@ class EmbeddingLookupThor(Cell):
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.getG = P.InsertGradientOf(self.save_gradient)
        self.cast = P.Cast()
        if context.get_context("device_target") == "Ascend":
            self.cube_matmul = P.MatMul(transpose_a=True)
        else:
            self.cube_matmul = P.MatMul(transpose_a=True)
        self.cube_matmul = P.MatMul(transpose_a=True)
        self.mul = P.Mul()
        self.on_value = Tensor(1.0, self.dtype)
        self.off_value = Tensor(0.0, self.dtype)
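Both branches of the removed `if/else` built the same operator, which is presumably why the hunk collapses them to a single assignment. For reference, a minimal sketch of what a `MatMul(transpose_a=True)` instance computes (the shapes are made up):

import numpy as np
import mindspore
from mindspore import Tensor, ops

cube_matmul = ops.MatMul(transpose_a=True)
a = Tensor(np.ones((8, 4)), mindspore.float32)  # e.g. 8 activation rows with 4 features
cov = cube_matmul(a, a)                         # A^T @ A, a (4, 4) covariance-style statistic
print(cov.shape)  # (4, 4)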
@@ -176,6 +176,14 @@ def caculate_matmul_shape(matrix_a_dim, matrix_g_dim, split_dim):
    return matrix_a_shape, matrix_g_shape


def get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map):
    """get layer type for dense layer and conv layer"""
    if subcell.weight.requires_grad:
        if "rpn_with_loss.rpn_convs_list." not in prefix.lower() \
                or "rpn_with_loss.rpn_convs_list.0." in prefix.lower():
            layertype_map.append(Other)


def find_net_layertype_recur(net, layertype_map):
    """get net layer type recursively."""
    cells = net.name_cells()
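A rough stand-alone illustration of the prefix filter in `get_layer_type_for_dense_and_conv` (the prefixes used here are hypothetical):

def keep_as_other(prefix):
    # Mirrors the condition above: skip rpn_convs_list.* entries except index 0.
    p = prefix.lower()
    return "rpn_with_loss.rpn_convs_list." not in p or "rpn_with_loss.rpn_convs_list.0." in p

print(keep_as_other("backbone.conv1"))                       # True
print(keep_as_other("rpn_with_loss.rpn_convs_list.0.conv"))  # True
print(keep_as_other("rpn_with_loss.rpn_convs_list.1.conv"))  # False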
@@ -198,10 +206,7 @@ def find_net_layertype_recur(net, layertype_map):
        elif isinstance(subcell, (nn.Conv2d, nn.Dense, nn.Embedding, nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose,
                                  nn.BatchNorm1d, nn.GroupNorm, nn.GlobalBatchNorm)):
            if isinstance(subcell, (nn.Dense, nn.Conv2d)):
                if "rpn_with_loss.rpn_convs_list." in prefix.lower():
                    continue
                elif subcell.weight.requires_grad:
                    layertype_map.append(Other)
                get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map)
            else:
                layertype_map.append(Other)
        else:
@@ -232,7 +237,6 @@ def get_layer_counter(layer_type, layer_counter, params, idx):
        if idx < len(params) - 1 and "bias" not in params[idx + 1].name.lower():
            layer_counter = layer_counter + 1
        else:
            # for example : embedding layer
            layer_counter = layer_counter + 1
    return layer_counter
@@ -740,7 +744,6 @@ class ThorAscend(Optimizer):
        self.one = Tensor(1, mstype.int32)
        self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)

    def _define_ascend_reducer(self, split_indices):
        """define ascend reducer"""
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
@@ -763,6 +766,40 @@ class ThorAscend(Optimizer):
            self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
            self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)

    def _get_weight_idx_map(self, layer_type, idx, weight_shape):
        """for Ascend, get weight idx map"""
        if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
            self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
            self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,)
            if layer_type == Embedding:
                a_pad_dim = 0
                g_pad_dim = 0
                self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)
                self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
            else:
                out_channels = weight_shape[0]
                g_pad_dim = self._get_pad_dim(out_channels)
                self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
                matrix_a_dim = weight_shape[1]
                if layer_type == Conv:
                    matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3]
                a_pad_dim = self._get_pad_dim(matrix_a_dim)
                self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)

            self.thor_layer_count = self.thor_layer_count + 1
            if layer_type == Conv:
                self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
                self.conv_layer_count = self.conv_layer_count + 1
            else:
                self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
        else:
            self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
            self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
            if layer_type == LayerNorm:
                self.weight_layertype_idx_map = self.weight_layertype_idx_map + (LayerNorm,)
            else:
                self.weight_layertype_idx_map = self.weight_layertype_idx_map + (Other,)

    def _process_matrix_init_and_weight_idx_map(self, net):
        """for Ascend, process matrix init shape, and get weight idx map"""
        layer_counter = 0
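To make the bookkeeping in the new `_get_weight_idx_map` easier to follow, a toy plain-Python rendition of the index tuples it accumulates (the layer-type constants and layer sequence are hypothetical stand-ins; the padding-dimension maps are omitted):

# Stand-in constants; in thor.py these are module-level layer-type tags.
Conv, FC, Embedding, LayerNorm, Other = 0, 1, 2, 3, 4

layers = [Conv, Conv, FC, LayerNorm]
weight_fim_idx_map, weight_conv_idx_map = (), ()
thor_layer_count = conv_layer_count = 0
for layer_type in layers:
    if layer_type in (Conv, FC, Embedding):
        weight_fim_idx_map += (thor_layer_count,)
        thor_layer_count += 1
        if layer_type == Conv:
            weight_conv_idx_map += (conv_layer_count,)
            conv_layer_count += 1
        else:
            weight_conv_idx_map += (-1,)
    else:
        weight_fim_idx_map += (-1,)
        weight_conv_idx_map += (-1,)

print(weight_fim_idx_map)   # (0, 1, 2, -1)
print(weight_conv_idx_map)  # (0, 1, -1, -1)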
@@ -819,39 +856,7 @@ class ThorAscend(Optimizer):
                                                   requires_grad=False)
                self.matrix_a = self.matrix_a + (fc_matrix_a,)
                self.matrix_g = self.matrix_g + (fc_matrix_g,)

            if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
                self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
                self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,)
                if layer_type == Embedding:
                    a_pad_dim = 0
                    g_pad_dim = 0
                    self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)
                    self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
                else:
                    out_channels = weight_shape[0]
                    g_pad_dim = self._get_pad_dim(out_channels)
                    self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
                    matrix_a_dim = weight_shape[1]
                    if layer_type == Conv:
                        matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3]
                    a_pad_dim = self._get_pad_dim(matrix_a_dim)
                    self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)

                self.thor_layer_count = self.thor_layer_count + 1
                if layer_type == Conv:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
                    self.conv_layer_count = self.conv_layer_count + 1
                else:
                    self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
            else:
                self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
                self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
                if layer_type == LayerNorm:
                    self.weight_layertype_idx_map = self.weight_layertype_idx_map + (LayerNorm,)
                else:
                    self.weight_layertype_idx_map = self.weight_layertype_idx_map + (Other,)

            self._get_weight_idx_map(layer_type, idx, weight_shape)
            # bert.cls1.output_bias: not a network layer, only a trainable param
            if "output_bias" not in self.params[idx].name.lower():
                layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)
@@ -957,6 +962,17 @@ class ThorAscend(Optimizer):
                matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
        return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce

    def _process_conv_matmul_device_pad(self, conv_layer_count, weight_shape, matrix_a_inv):
        """process conv matmul device pad"""
        if self.device_shape_pad_flag[conv_layer_count]:
            kernel_hw = weight_shape[2] * weight_shape[3]
            in_channels = weight_shape[1]
            matrix_a_inv = self.reshape(matrix_a_inv, (kernel_hw, in_channels, kernel_hw, in_channels))
            matrix_a_inv = P.Pad(((0, 0), (0, self.C0 - in_channels), (0, 0),
                                  (0, self.C0 - in_channels)))(matrix_a_inv)
        return matrix_a_inv

    def _get_ainv_ginv_amax_gmax_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
                                      matrix_a_max_allreduce, matrix_g_max_allreduce):
        """get matrixA inverse list, matrixG inverse list, matrixA_max list, matrixG_max list"""
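A small numpy sketch of the reshape-then-pad step factored out into `_process_conv_matmul_device_pad` (it assumes `C0` is the Ascend cube block size, commonly 16; the shapes are illustrative):

import numpy as np

C0 = 16                        # assumed cube block size on Ascend
kernel_hw, in_channels = 9, 3  # e.g. a 3x3 kernel with 3 input channels
a_inv = np.ones((kernel_hw * in_channels, kernel_hw * in_channels), dtype=np.float32)

# Same reshape-then-pad pattern as the method above:
a_inv = a_inv.reshape(kernel_hw, in_channels, kernel_hw, in_channels)
a_inv = np.pad(a_inv, ((0, 0), (0, C0 - in_channels), (0, 0), (0, C0 - in_channels)))
print(a_inv.shape)  # (9, 16, 9, 16)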
@@ -1008,12 +1024,7 @@ class ThorAscend(Optimizer):
                    matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels))

                if matmul_support_flag == 1:
                    if self.device_shape_pad_flag[conv_layer_count]:
                        kernel_hw = weight_shape[2] * weight_shape[3]
                        in_channels = weight_shape[1]
                        matrix_a_inv = self.reshape(matrix_a_inv, (kernel_hw, in_channels, kernel_hw, in_channels))
                        matrix_a_inv = P.Pad(((0, 0), (0, self.C0 - in_channels), (0, 0),
                                              (0, self.C0 - in_channels)))(matrix_a_inv)
                    matrix_a_inv = self._process_conv_matmul_device_pad(conv_layer_count, weight_shape, matrix_a_inv)
                    matrix_a_inv_shape = self.shape(self.matrix_a[thor_layer_count])
                    matrix_a_device_temp_shape = (matrix_a_inv_shape[0], matrix_a_inv_shape[2],
                                                  matrix_a_inv_shape[1], matrix_a_inv_shape[3])
@@ -1083,47 +1094,52 @@ class ThorAscend(Optimizer):
            g = self.cast(g, mstype.float32)
        return g

    def _get_second_gradients_one(self, params_len, gradients, new_grads):
        """get second gradients one"""
        for i in range(params_len):
            g = gradients[i]
            thor_layer_count = self.weight_fim_idx_map[i]
            conv_layer_count = self.weight_conv_idx_map[i]
            layer_type = self.weight_layertype_idx_map[i]
            matrix_a = self.matrix_a[thor_layer_count]
            matrix_g = self.matrix_g[thor_layer_count]
            matrix_max = self.matrix_max_inv[thor_layer_count]
            grad_shape = self.shape(g)
            if layer_type == FC:
                if grad_shape[0] == 1001:
                    g = self.cube_matmul_left_fc(matrix_g, g)
                    g = self.cube_matmul_right_fc(g, matrix_a, matrix_max)
                else:
                    temp_a = self.cast(matrix_a, mstype.float16)
                    temp_g = self.cast(matrix_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(temp_g, g)
                    g = self.matmul(g, temp_a)
                    g = self.cast(g, mstype.float32)
                    g = self.mul(g, matrix_max)
            elif layer_type == Conv:
                matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
                if matmul_support_flag == 1:
                    g = self.cube_matmul_left(matrix_g, g)
                    g = self.cube_matmul_right_mul(g, matrix_a, matrix_max)
                else:
                    g = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
                    temp_a = self.cast(matrix_a, mstype.float16)
                    temp_g = self.cast(matrix_g, mstype.float16)
                    g = self.cast(g, mstype.float16)
                    g = self.matmul(temp_g, g)
                    g = self.matmul(g, temp_a)
                    g = self.cast(g, mstype.float32)
                    g = self.mul(g, matrix_max)
                    g = self.reshape(g, grad_shape)
            new_grads = new_grads + (g,)
        return new_grads

    def _get_second_gradients(self, new_grads, damping_step, gradients):
        """get second gradients for thor"""
        params_len = len(self.params)
        if self.conv_layer_count > 0:
            for i in range(params_len):
                g = gradients[i]
                thor_layer_count = self.weight_fim_idx_map[i]
                conv_layer_count = self.weight_conv_idx_map[i]
                layer_type = self.weight_layertype_idx_map[i]
                matrix_a = self.matrix_a[thor_layer_count]
                matrix_g = self.matrix_g[thor_layer_count]
                matrix_max = self.matrix_max_inv[thor_layer_count]
                grad_shape = self.shape(g)
                if layer_type == FC:
                    if grad_shape[0] == 1001:
                        g = self.cube_matmul_left_fc(matrix_g, g)
                        g = self.cube_matmul_right_fc(g, matrix_a, matrix_max)
                    else:
                        temp_a = self.cast(matrix_a, mstype.float16)
                        temp_g = self.cast(matrix_g, mstype.float16)
                        g = self.cast(g, mstype.float16)
                        g = self.matmul(temp_g, g)
                        g = self.matmul(g, temp_a)
                        g = self.cast(g, mstype.float32)
                        g = self.mul(g, matrix_max)
                elif layer_type == Conv:
                    matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
                    if matmul_support_flag == 1:
                        g = self.cube_matmul_left(matrix_g, g)
                        g = self.cube_matmul_right_mul(g, matrix_a, matrix_max)
                    else:
                        g = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
                        temp_a = self.cast(matrix_a, mstype.float16)
                        temp_g = self.cast(matrix_g, mstype.float16)
                        g = self.cast(g, mstype.float16)
                        g = self.matmul(temp_g, g)
                        g = self.matmul(g, temp_a)
                        g = self.cast(g, mstype.float32)
                        g = self.mul(g, matrix_max)
                        g = self.reshape(g, grad_shape)
                new_grads = new_grads + (g,)
            new_grads = self._get_second_gradients_one(params_len, gradients, new_grads)
        else:
            for i in range(params_len):
                g = gradients[i]
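Both the FC and Conv float16 fallback paths above apply the same core transform, G_inv @ g @ A_inv scaled by matrix_max; a numpy sketch of that step with illustrative shapes:

import numpy as np

out_dim, in_dim = 4, 6
g = np.ones((out_dim, in_dim), dtype=np.float32)  # flattened layer gradient
matrix_g = np.eye(out_dim, dtype=np.float32)      # stands in for the inverse G factor
matrix_a = np.eye(in_dim, dtype=np.float32)       # stands in for the inverse A factor
matrix_max = np.float32(0.5)                      # per-layer scaling factor

# Same order of operations as the fallback path: (G @ g) @ A, then scale.
second_order_g = (matrix_g @ g) @ matrix_a * matrix_max
print(second_order_g.shape)  # (4, 6)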
@@ -54,13 +54,12 @@ def calculate_gain(nonlinearity, param=None):
        res = math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
            neg_slope = 0.01
        elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
            # True/False are instances of int, hence check above
            negative_slope = param
            neg_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        res = math.sqrt(2.0 / (1 + negative_slope ** 2))
            raise ValueError("neg_slope {} not a valid number".format(param))
        res = math.sqrt(2.0 / (1 + neg_slope ** 2))
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
    return res
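A quick numeric check of the leaky-ReLU gain formula above (plain Python, independent of MindSpore):

import math

neg_slope = 0.01  # the default used when param is None
gain = math.sqrt(2.0 / (1 + neg_slope ** 2))
print(round(gain, 6))            # 1.414143
print(round(math.sqrt(2.0), 6))  # 1.414214, the plain ReLU gain for comparison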
@@ -89,7 +88,7 @@ def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
        raise ValueError("Unsupported mode {}, please use one of {}".format(mode, valid_modes))
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out
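For reference, a tiny sketch of how fan values are conventionally derived from a weight shape (a hypothetical helper; the real `_calculate_fan_in_and_fan_out` lives in the same file and may differ in detail):

def fan_in_and_fan_out(weight_shape):
    # (out_channels, in_channels, kh, kw) for a conv weight, (out, in) for a dense weight
    receptive_field = 1
    for dim in weight_shape[2:]:
        receptive_field *= dim
    fan_in = weight_shape[1] * receptive_field
    fan_out = weight_shape[0] * receptive_field
    return fan_in, fan_out

print(fan_in_and_fan_out((240, 120, 4, 4)))  # (1920, 3840)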