!22887 clearn codecheck for master and fix api error

Merge pull request !22887 from wangshuangling/master
This commit is contained in:
i-robot 2021-09-06 06:17:17 +00:00 committed by Gitee
commit e02819cefe
3 changed files with 110 additions and 93 deletions

View File

@ -292,6 +292,9 @@ class Conv2dThor(_ConvThor):
:math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
(\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.
Note:
For Ascend, the type of inputs should be subclass of Tensor[Float16], Tensor[Int8].
For GPU, the type of inputs should be subclass of Tensor[Float32].
Args:
in_channels (int): The number of the input channel :math:`C_{in}`.
@ -351,7 +354,8 @@ class Conv2dThor(_ConvThor):
Examples:
>>> net = nn.Conv2dThor(120, 240, 4, has_bias=False, weight_init='normal')
>>> x = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> # for Ascend
>>> x = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float16)
>>> print(net(x).shape)
(1, 240, 1024, 640)
"""
@ -638,6 +642,7 @@ class EmbeddingThor(Cell):
self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
return s
@constexpr
def _make_axis_range(start, end):
axis = tuple(range(start, end))
@ -804,10 +809,7 @@ class EmbeddingLookupThor(Cell):
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.getG = P.InsertGradientOf(self.save_gradient)
self.cast = P.Cast()
if context.get_context("device_target") == "Ascend":
self.cube_matmul = P.MatMul(transpose_a=True)
else:
self.cube_matmul = P.MatMul(transpose_a=True)
self.cube_matmul = P.MatMul(transpose_a=True)
self.mul = P.Mul()
self.on_value = Tensor(1.0, self.dtype)
self.off_value = Tensor(0.0, self.dtype)

View File

@ -176,6 +176,14 @@ def caculate_matmul_shape(matrix_a_dim, matrix_g_dim, split_dim):
return matrix_a_shape, matrix_g_shape
def get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map):
"""get layer type for dense layer and conv layer"""
if subcell.weight.requires_grad:
if "rpn_with_loss.rpn_convs_list." not in prefix.lower() \
or "rpn_with_loss.rpn_convs_list.0." in prefix.lower():
layertype_map.append(Other)
def find_net_layertype_recur(net, layertype_map):
"""get net layer type recursively."""
cells = net.name_cells()
@ -198,10 +206,7 @@ def find_net_layertype_recur(net, layertype_map):
elif isinstance(subcell, (nn.Conv2d, nn.Dense, nn.Embedding, nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose,
nn.BatchNorm1d, nn.GroupNorm, nn.GlobalBatchNorm)):
if isinstance(subcell, (nn.Dense, nn.Conv2d)):
if "rpn_with_loss.rpn_convs_list." in prefix.lower():
continue
elif subcell.weight.requires_grad:
layertype_map.append(Other)
get_layer_type_for_dense_and_conv(subcell, prefix, layertype_map)
else:
layertype_map.append(Other)
else:
@ -232,7 +237,6 @@ def get_layer_counter(layer_type, layer_counter, params, idx):
if idx < len(params) - 1 and "bias" not in params[idx + 1].name.lower():
layer_counter = layer_counter + 1
else:
# for example : embedding layer
layer_counter = layer_counter + 1
return layer_counter
@ -740,7 +744,6 @@ class ThorAscend(Optimizer):
self.one = Tensor(1, mstype.int32)
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
def _define_ascend_reducer(self, split_indices):
"""define ascend reducer"""
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
@ -763,6 +766,40 @@ class ThorAscend(Optimizer):
self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)
def _get_weight_idx_map(self, layer_type, idx, weight_shape):
"""for Ascend, get weight idx map"""
if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,)
if layer_type == Embedding:
a_pad_dim = 0
g_pad_dim = 0
self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)
self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
else:
out_channels = weight_shape[0]
g_pad_dim = self._get_pad_dim(out_channels)
self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
matrix_a_dim = weight_shape[1]
if layer_type == Conv:
matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3]
a_pad_dim = self._get_pad_dim(matrix_a_dim)
self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)
self.thor_layer_count = self.thor_layer_count + 1
if layer_type == Conv:
self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
self.conv_layer_count = self.conv_layer_count + 1
else:
self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
else:
self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
if layer_type == LayerNorm:
self.weight_layertype_idx_map = self.weight_layertype_idx_map + (LayerNorm,)
else:
self.weight_layertype_idx_map = self.weight_layertype_idx_map + (Other,)
def _process_matrix_init_and_weight_idx_map(self, net):
"""for Ascend, process matrix init shape, and get weight idx map"""
layer_counter = 0
@ -819,39 +856,7 @@ class ThorAscend(Optimizer):
requires_grad=False)
self.matrix_a = self.matrix_a + (fc_matrix_a,)
self.matrix_g = self.matrix_g + (fc_matrix_g,)
if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,)
if layer_type == Embedding:
a_pad_dim = 0
g_pad_dim = 0
self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)
self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
else:
out_channels = weight_shape[0]
g_pad_dim = self._get_pad_dim(out_channels)
self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
matrix_a_dim = weight_shape[1]
if layer_type == Conv:
matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3]
a_pad_dim = self._get_pad_dim(matrix_a_dim)
self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)
self.thor_layer_count = self.thor_layer_count + 1
if layer_type == Conv:
self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
self.conv_layer_count = self.conv_layer_count + 1
else:
self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
else:
self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
if layer_type == LayerNorm:
self.weight_layertype_idx_map = self.weight_layertype_idx_map + (LayerNorm,)
else:
self.weight_layertype_idx_map = self.weight_layertype_idx_map + (Other,)
self._get_weight_idx_map(layer_type, idx, weight_shape)
# bert.cls1.output_bias: not a network layer, only a trainable param
if "output_bias" not in self.params[idx].name.lower():
layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)
@ -957,6 +962,17 @@ class ThorAscend(Optimizer):
matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce
def _process_conv_matmul_device_pad(self, conv_layer_count, weight_shape, matrix_a_inv):
"""process conv matmul device pad"""
if self.device_shape_pad_flag[conv_layer_count]:
kernel_hw = weight_shape[2] * weight_shape[3]
in_channels = weight_shape[1]
matrix_a_inv = self.reshape(matrix_a_inv, (kernel_hw, in_channels, kernel_hw, in_channels))
matrix_a_inv = P.Pad(((0, 0), (0, self.C0 - in_channels), (0, 0),
(0, self.C0 - in_channels)))(matrix_a_inv)
return matrix_a_inv
def _get_ainv_ginv_amax_gmax_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
matrix_a_max_allreduce, matrix_g_max_allreduce):
"""get matrixA inverse list, matrixG inverse list, matrixA_max list, matrixG_max list"""
@ -1008,12 +1024,7 @@ class ThorAscend(Optimizer):
matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels))
if matmul_support_flag == 1:
if self.device_shape_pad_flag[conv_layer_count]:
kernel_hw = weight_shape[2] * weight_shape[3]
in_channels = weight_shape[1]
matrix_a_inv = self.reshape(matrix_a_inv, (kernel_hw, in_channels, kernel_hw, in_channels))
matrix_a_inv = P.Pad(((0, 0), (0, self.C0 - in_channels), (0, 0),
(0, self.C0 - in_channels)))(matrix_a_inv)
matrix_a_inv = self._process_conv_matmul_device_pad(conv_layer_count, weight_shape, matrix_a_inv)
matrix_a_inv_shape = self.shape(self.matrix_a[thor_layer_count])
matrix_a_device_temp_shape = (matrix_a_inv_shape[0], matrix_a_inv_shape[2],
matrix_a_inv_shape[1], matrix_a_inv_shape[3])
@ -1083,47 +1094,52 @@ class ThorAscend(Optimizer):
g = self.cast(g, mstype.float32)
return g
def _get_second_gradients_one(self, params_len, gradients, new_grads):
"""get second gradients one"""
for i in range(params_len):
g = gradients[i]
thor_layer_count = self.weight_fim_idx_map[i]
conv_layer_count = self.weight_conv_idx_map[i]
layer_type = self.weight_layertype_idx_map[i]
matrix_a = self.matrix_a[thor_layer_count]
matrix_g = self.matrix_g[thor_layer_count]
matrix_max = self.matrix_max_inv[thor_layer_count]
grad_shape = self.shape(g)
if layer_type == FC:
if grad_shape[0] == 1001:
g = self.cube_matmul_left_fc(matrix_g, g)
g = self.cube_matmul_right_fc(g, matrix_a, matrix_max)
else:
temp_a = self.cast(matrix_a, mstype.float16)
temp_g = self.cast(matrix_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, matrix_max)
elif layer_type == Conv:
matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
if matmul_support_flag == 1:
g = self.cube_matmul_left(matrix_g, g)
g = self.cube_matmul_right_mul(g, matrix_a, matrix_max)
else:
g = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
temp_a = self.cast(matrix_a, mstype.float16)
temp_g = self.cast(matrix_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, matrix_max)
g = self.reshape(g, grad_shape)
new_grads = new_grads + (g,)
return new_grads
def _get_second_gradients(self, new_grads, damping_step, gradients):
"""get second gradients for thor"""
params_len = len(self.params)
if self.conv_layer_count > 0:
for i in range(params_len):
g = gradients[i]
thor_layer_count = self.weight_fim_idx_map[i]
conv_layer_count = self.weight_conv_idx_map[i]
layer_type = self.weight_layertype_idx_map[i]
matrix_a = self.matrix_a[thor_layer_count]
matrix_g = self.matrix_g[thor_layer_count]
matrix_max = self.matrix_max_inv[thor_layer_count]
grad_shape = self.shape(g)
if layer_type == FC:
if grad_shape[0] == 1001:
g = self.cube_matmul_left_fc(matrix_g, g)
g = self.cube_matmul_right_fc(g, matrix_a, matrix_max)
else:
temp_a = self.cast(matrix_a, mstype.float16)
temp_g = self.cast(matrix_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, matrix_max)
elif layer_type == Conv:
matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
if matmul_support_flag == 1:
g = self.cube_matmul_left(matrix_g, g)
g = self.cube_matmul_right_mul(g, matrix_a, matrix_max)
else:
g = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
temp_a = self.cast(matrix_a, mstype.float16)
temp_g = self.cast(matrix_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, matrix_max)
g = self.reshape(g, grad_shape)
new_grads = new_grads + (g,)
new_grads = self._get_second_gradients_one(params_len, gradients, new_grads)
else:
for i in range(params_len):
g = gradients[i]

View File

@ -54,13 +54,12 @@ def calculate_gain(nonlinearity, param=None):
res = math.sqrt(2.0)
elif nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
neg_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
# True/False are instances of int, hence check above
negative_slope = param
neg_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
res = math.sqrt(2.0 / (1 + negative_slope ** 2))
raise ValueError("neg_slope {} not a valid number".format(param))
res = math.sqrt(2.0 / (1 + neg_slope ** 2))
else:
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
return res
@ -89,7 +88,7 @@ def _calculate_correct_fan(tensor, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
raise ValueError("Unsupported mode {}, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == 'fan_in' else fan_out