!22358 thor generalization code submit

Merge pull request !22358 from wangshuangling/master
This commit is contained in:
i-robot 2021-08-25 13:46:27 +00:00 committed by Gitee
commit 76a37daa43
12 changed files with 691 additions and 184 deletions

View File

@ -33,7 +33,7 @@ from .quant import *
from .math import *
from .combined import *
from .timedistributed import *
from .thor_layer import DenseThor, Conv2dThor, EmbeddingThor
from .thor_layer import DenseThor, Conv2dThor, EmbeddingThor, EmbeddingLookupThor
__all__ = []
__all__.extend(activation.__all__)

View File

@ -16,6 +16,7 @@
"""layers for second order optimization"""
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.log as logger
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer, Initializer
from mindspore.ops import operations as P
@ -24,9 +25,14 @@ from mindspore._checkparam import Validator, Rel, twice
from mindspore import context
from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context
from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
from mindspore.context import ParallelMode
from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F
from .basic import ClipByNorm
__all__ = ['DenseThor', 'Conv2dThor', 'EmbeddingThor']
__all__ = ['DenseThor', 'Conv2dThor', 'EmbeddingThor', 'EmbeddingLookupThor']
class DenseThor(Cell):
@ -75,6 +81,7 @@ class DenseThor(Cell):
[[ 6. 6. 6. 6.]
[ 12. 12. 12. 12. ]]
"""
def __init__(self,
in_channels,
out_channels,
@ -93,7 +100,6 @@ class DenseThor(Cell):
weight_init.shape[1] != in_channels:
raise ValueError("Weight init shape error.")
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
self.bias = None
if self.has_bias:
if isinstance(bias_init, Tensor):
@ -106,38 +112,25 @@ class DenseThor(Cell):
self.activation = get_activation(activation)
self.activation_flag = self.activation is not None
self.matrix_a = Parameter(Tensor(np.zeros([in_channels, in_channels]).astype(np.float32)),
self.matrix_a = Parameter(Tensor(np.eye(in_channels).astype(np.float32)),
name='matrix_a', requires_grad=False)
self.matrix_g = Parameter(Tensor(np.eye(out_channels).astype(np.float32)),
name="matrix_g", requires_grad=False)
self.shape = P.Shape()
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.mul = P.Mul()
self.is_Ascend = True
self.split_dim = 128
if context.get_context("device_target") == "Ascend":
self._process_ascend_dense_thor(out_channels)
self._process_ascend_dense_thor(out_channels, in_channels)
else:
self.is_Ascend = False
self.matrix_g = Parameter(Tensor(np.eye(out_channels).astype(np.float32)),
name="matrix_g", requires_grad=False)
self.cube_matmul = P.MatMul(transpose_a=True)
self.getG = P.InsertGradientOf(self.save_gradient)
def _process_ascend_dense_thor(self, out_channels):
def _process_ascend_dense_thor(self, out_channels, in_channels):
"""process ascend dense thor"""
if out_channels == 1001:
self.matrix_g = Parameter(Tensor(np.zeros([1024, 1024]).astype(np.float32)),
name='matrix_g', requires_grad=False)
self.pad = P.Pad(((0, 23), (0, 23)))
self.pad1 = P.Pad(((0, 7), (0, 7)))
self.slice = P.Slice()
self.add = P.TensorAdd()
else:
self.matrix_g = Parameter(Tensor(np.eye(out_channels).astype(np.float32)),
name="matrix_g", requires_grad=False)
self.abs = P.Abs()
self.reduce_max = P.ReduceMax(keep_dims=False)
self.neg = P.Neg()
self.reduce_sum = P.ReduceSum()
self.matmul = P.MatMul(transpose_b=True)
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
self.cast = P.Cast()
@ -155,8 +148,6 @@ class DenseThor(Cell):
normalizer = self.cast(shape[0], mstype.float32)
matrix_g = self.cube_matmul(dout, dout)
matrix_g = self.mul(matrix_g, 1.0 / normalizer)
if self.out_channels == 1001:
matrix_g = P.Pad(((0, 23), (0, 23)))(matrix_g)
self.matrix_g = matrix_g
else:
dout_shape = self.shape(dout)
@ -380,7 +371,6 @@ class Conv2dThor(_ConvThor):
stride=self.stride, dilation=self.dilation, group=self.group)
self._init_depthwise_conv2d(weight_init)
self.bias_add = P.BiasAdd()
self.thor = True
self.hw = kernel_size[0] * kernel_size[1]
self.matrix_a_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
@ -389,8 +379,8 @@ class Conv2dThor(_ConvThor):
self.reshape = P.Reshape()
self.mul = P.Mul()
self.cast = P.Cast()
self.a_normalizer = Parameter(initializer(0, [1], mstype.float32), name="a_normalizer", requires_grad=False)
self.g_normalizer = Parameter(initializer(0, [1], mstype.float32), name="g_normalizer", requires_grad=False)
self.a_normalizer = Parameter(initializer(1, [1], mstype.float32), name="a_normalizer", requires_grad=False)
self.g_normalizer = Parameter(initializer(1, [1], mstype.float32), name="g_normalizer", requires_grad=False)
self.is_Ascend = True
if context.get_context("device_target") == "Ascend":
self._process_ascend_conv2d_thor(kernel_size, stride)
@ -409,32 +399,18 @@ class Conv2dThor(_ConvThor):
"""process ascend conv2d thor"""
ksizes = (1, kernel_size[0], kernel_size[1], 1)
strides = (1, stride[0], stride[1], 1)
ksizes_tbe = (kernel_size[0], kernel_size[1])
self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides)
self.transpose = P.Transpose()
self.reshape = P.Reshape()
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
self.transpose02314 = P.CusTranspose02314()
dampinga_dim = self.matrix_a_dim
self.diag_block_dim = 128
if (self.matrix_a_dim % self.diag_block_dim) != 0 and self.matrix_a_dim > self.diag_block_dim:
dampinga_dim = (self.matrix_a_dim // self.diag_block_dim + 1) * self.diag_block_dim
dampingg_dim = self.matrix_g_dim
if (self.matrix_g_dim % self.diag_block_dim) != 0 and self.matrix_g_dim > self.diag_block_dim:
dampingg_dim = (self.matrix_g_dim // self.diag_block_dim + 1) * self.diag_block_dim
self.matrix_a_cov = Parameter(Tensor(np.zeros([dampinga_dim, dampinga_dim]).astype(np.float32)),
self.matrix_a_cov = Parameter(Tensor(np.eye(self.matrix_a_dim).astype(np.float32)),
name='matrix_a', requires_grad=False)
self.matrix_g_cov = Parameter(Tensor(np.zeros([dampingg_dim, dampingg_dim]).astype(np.float32)),
self.matrix_g_cov = Parameter(Tensor(np.eye(self.matrix_g_dim).astype(np.float32)),
name='matrix_g', requires_grad=False)
self.channels_slice_flag = False
self.C0 = 16
if self.in_channels % self.C0 != 0:
self.channels_slice_flag = True
self.pada_flag = False
if (self.matrix_a_dim // self.diag_block_dim) * self.diag_block_dim != self.matrix_a_dim \
and self.matrix_a_dim > self.diag_block_dim:
self.pada_flag = True
pad_dim = self.diag_block_dim - self.matrix_a_dim % self.diag_block_dim
self.pada = P.Pad(((0, pad_dim), (0, pad_dim)))
self.slice = P.Slice()
self.im2col = P.NewIm2Col(ksizes=ksizes_tbe, strides=stride[0], padding_mode="SAME")
def _init_depthwise_conv2d(self, weight_init):
"""Initialize depthwise conv2d op"""
@ -460,7 +436,9 @@ class Conv2dThor(_ConvThor):
"""save_gradient"""
out = dout
if self.is_Ascend:
dout = self.transpose02314(dout)
dout_shape = self.shape(dout)
dout = self.transpose(dout, (0, 2, 3, 1))
dout = self.reshape(dout, (-1, dout_shape[1]))
dout_shape = self.shape(dout)
normalizer = dout_shape[0]
matrix_g = self.cube_matmul(dout, dout)
@ -483,23 +461,24 @@ class Conv2dThor(_ConvThor):
def construct(self, x):
if self.thor:
matrix_a = self.img2col(x)
matrix_a_shape = self.shape(matrix_a)
if self.is_Ascend:
matrix_a = self.im2col(x)
matrix_a_shape = self.shape(matrix_a)
y = matrix_a_shape[3]
matrix_a = self.reshape(matrix_a, (-1, y))
matrix_a_shape = self.shape(matrix_a)
normalizer = matrix_a_shape[0]
matrix_a = self.cube_matmul(matrix_a, matrix_a)
if self.channels_slice_flag:
matrix_a = self.reshape(matrix_a, (self.hw, self.C0, self.hw, self.C0))
matrix_a = self.slice(matrix_a, (0, 0, 0, 0),
(self.hw, self.in_channels, self.hw, self.in_channels))
matrix_a = self.reshape(matrix_a, (self.matrix_a_dim, self.matrix_a_dim))
normalizer = self.cast(normalizer, mstype.float32)
matrix_a = self.mul(matrix_a, 1.0 / normalizer)
if self.pada_flag:
matrix_a = self.pada(matrix_a)
self.a_normalizer = normalizer
self.matrix_a_cov = matrix_a
weight = self.cast(self.weight, mstype.float16)
output = self.conv2d(x, weight)
output = self.getG(output)
else:
matrix_a = self.img2col(x)
matrix_a_shape = self.shape(matrix_a)
matrix_a = self.reshape(matrix_a, (matrix_a_shape[0] * matrix_a_shape[1] * matrix_a_shape[2],
matrix_a_shape[3], -1))
matrix_a = self.reduce_mean(matrix_a, 1)
@ -510,11 +489,19 @@ class Conv2dThor(_ConvThor):
matrix_a = self.mul(matrix_a, 1.0 / normalizer)
self.a_normalizer = normalizer
self.matrix_a_cov = matrix_a
output = self.conv2d(x, self.weight)
output = self.getG(output)
output = self.conv2d(x, self.weight)
output = self.getG(output)
else:
output = self.conv2d(x, self.weight)
if self.has_bias:
if self.is_Ascend:
weight = self.cast(self.weight, mstype.float16)
output = self.conv2d(x, weight)
else:
output = self.conv2d(x, self.weight)
if self.has_bias:
if self.is_Ascend:
bias = self.cast(self.bias, mstype.float16)
output = self.bias_add(output, bias)
else:
output = self.bias_add(output, self.bias)
return output
@ -650,3 +637,296 @@ class EmbeddingThor(Cell):
s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
return s
@constexpr
def _make_axis_range(start, end):
axis = tuple(range(start, end))
return axis
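As a quick illustration (a sketch; in MindSpore the constexpr is evaluated at graph-compile time, but the arithmetic is plain Python), this helper builds the axis tuple that the max_norm clipping in construct applies ClipByNorm over, i.e. the embedding axes that follow the index axes:

def make_axis_range(start, end):
    # plain-Python equivalent of the constexpr above
    return tuple(range(start, end))

# rank(indices) = 2 and rank(out) = 3 leave exactly the embedding axis:
assert make_axis_range(2, 3) == (2,)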
class EmbeddingLookupThor(Cell):
r"""
Returns a slice of the input tensor based on the specified indices,
and saves the information needed for THOR.
This module has the same function as EmbeddingLookup, but additionally saves the matrices A and G of the
embedding lookup layer needed for THOR;
the details can be found in the paper: https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf
Args:
vocab_size (int): The size of the dictionary of embeddings.
embedding_size (int): The size of each embedding vector.
param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
Refer to class `initializer` for the values of string when a string is specified.
Default: 'normal'.
target (str): Specifies the target where the op is executed. The value must be in
['DEVICE', 'CPU']. Default: 'CPU'.
slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must be one of the
slice modes provided by nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
manual_shapes (tuple): The accompanying array of manual split shapes used in field slice mode.
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32 or None.
Default: None
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true.
Default: True.
vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only for the
'DEVICE' target, in which case the moment parameter of the corresponding optimizer is also set to the
cache size. Note that the cache consumes 'DEVICE' memory, so set a reasonable value to avoid
running out of memory.
Inputs:
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
Outputs:
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
Raises:
ValueError: If `target` is neither 'CPU' nor 'DEVICE'.
ValueError: If `slice_mode` is not one of 'batch_slice' or 'field_slice' or
'table_row_slice' or 'table_column_slice'.
ValueError: If `sparse` is False and `target` is 'CPU'.
ValueError: If `slice_mode` is 'field_slice' and `manual_shapes` is None.
TypeError: If `vocab_size` or `embedding_size` or `vocab_cache_size` is not an int.
TypeError: If `sparse` is not a bool or `manual_shapes` is not a tuple.
ValueError: If `vocab_size` or `embedding_size` is less than 1.
ValueError: If `vocab_cache_size` is less than 0.
Supported Platforms:
``Ascend``
Examples:
>>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
>>> result = nn.EmbeddingLookupThor(4, 2)(input_indices)
>>> print(result.shape)
(2, 2, 2)
"""
BATCH_SLICE = "batch_slice"
FIELD_SLICE = "field_slice"
TABLE_ROW_SLICE = "table_row_slice"
TABLE_COLUMN_SLICE = "table_column_slice"
def __init__(self, vocab_size, embedding_size, param_init='normal',
target='CPU', slice_mode='batch_slice', manual_shapes=None,
max_norm=None, sparse=True, vocab_cache_size=0):
super(EmbeddingLookupThor, self).__init__()
Validator.check_value_type('sparse', sparse, [bool], self.cls_name)
self.vocab_size = Validator.check_positive_int(vocab_size, 'vocab_size')
self.vocab_cache_size = Validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
self.target = target
self.sparse = sparse
self.cache_enable = self.vocab_cache_size > 0
self.forward_unique = False
self.dtype = mstype.float16
if target not in ('CPU', 'DEVICE'):
raise ValueError('Attr \'target\' of \'EmbeddingLookupThor\' Op passed '
+ str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
if not sparse and target == 'CPU':
raise ValueError('When target is CPU, embedding_lookup must be sparse.')
if sparse:
self.gatherv2 = P.SparseGatherV2()
else:
self.gatherv2 = P.Gather()
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
enable_ps = _get_ps_context("enable_ps")
if enable_ps:
self._process_vocab_cache(slice_mode)
self.embedding_size = Validator.check_positive_int(embedding_size, 'embedding_size')
self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size],
mstype.float16), name='embedding_table')
parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.gather_revert = P.Gather()
self.reshape_first = P.Reshape()
self.reshape = P.Reshape()
self.unique = P.Unique()
self.shape = P.Shape()
if is_auto_parallel:
self.unique = P.Unique().shard(((1,),))
if self.cache_enable and enable_ps:
self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
if is_auto_parallel:
self.unique.add_prim_attr('cache_enable', True)
indices_shape_size = 2
if slice_mode == "field_slice" and is_auto_parallel:
if not manual_shapes:
raise ValueError("in slice field mode, the manual_shapes should not be none")
if not isinstance(manual_shapes, tuple):
raise TypeError("manual_shapes type must be tuple(int) cannot be {}!".format(type(manual_shapes)))
for dim in manual_shapes:
Validator.check_positive_int(dim, 'manual shape dim', self.cls_name)
self.gatherv2.add_prim_attr("manual_split", manual_shapes)
self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
elif slice_mode == "table_row_slice" and is_auto_parallel:
full_batch = _get_full_batch()
if (target == 'DEVICE' and not full_batch) or (self.cache_enable and enable_ps and sparse):
indices_shape_size = 1
self.gather_revert.shard(((1, 1), (get_group_size(),)))
self.forward_unique = True
indices_strategy = (1,) * indices_shape_size
self.gatherv2.shard(((get_group_size(), 1), indices_strategy))
self.embeddinglookup.shard(((get_group_size(), 1), indices_strategy))
elif slice_mode == "table_column_slice" and is_auto_parallel:
if target == 'DEVICE':
indices_shape_size = 1
self.gather_revert.shard(((1, get_group_size()), (1,)))
self.forward_unique = True
indices_strategy = (1,) * indices_shape_size
self.gatherv2.shard(((1, get_group_size()), indices_strategy))
self.embeddinglookup.shard(((1, get_group_size()), indices_strategy))
elif slice_mode == "batch_slice" and is_auto_parallel:
indices_strategy = [get_group_size()]
indices_strategy.extend([1] * (indices_shape_size - 1))
indices_strategy = tuple(indices_strategy)
self.gatherv2.shard(((1, 1), indices_strategy))
self.embeddinglookup.shard(((1, 1), indices_strategy))
else:
if is_auto_parallel:
raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
+ str(slice_mode))
if self.cache_enable and not enable_ps:
if parallel_mode != ParallelMode.STAND_ALONE:
raise ValueError("parallel mode haven't supported cache enable yet.")
self._set_cache_enable()
self.embedding_table.unique = self.forward_unique
self.max_norm = max_norm
if self.max_norm is not None:
self.max_norm = Validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
self.max_norm = Tensor(self.max_norm, dtype=mstype.float16)
self.thor = True
self.matrix_a = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)),
name='matrix_a', requires_grad=False)
self.matrix_g = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)),
name="matrix_g", requires_grad=False)
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.getG = P.InsertGradientOf(self.save_gradient)
self.cast = P.Cast()
if context.get_context("device_target") == "Ascend":
self.cube_matmul = P.MatMul(transpose_a=True)
else:
self.cube_matmul = P.MatMul(transpose_a=True)
self.mul = P.Mul()
self.on_value = Tensor(1.0, self.dtype)
self.off_value = Tensor(0.0, self.dtype)
self.one_hot = P.OneHot()
def save_gradient(self, dout):
"""
this function only for thor optimizer
save_gradient
"""
out = dout
shape = self.shape(dout)
normalizer = self.cast(shape[0], mstype.float16)
dout = self.reshape(dout, (-1, self.embedding_size))
matrix_g = self.cube_matmul(dout, dout)
matrix_g = self.mul(matrix_g, 1.0 / normalizer)
matrix_g = self.cast(matrix_g, mstype.float16)
self.matrix_g = matrix_g
return out
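The hook computes the Fisher factor G as the normalized Gram matrix of the output gradients. A NumPy sketch of the same computation, with illustrative shapes:

import numpy as np

dout = np.random.randn(6, 4).astype(np.float32)   # (tokens, embedding_size)
normalizer = dout.shape[0]
matrix_g = (dout.T @ dout) / normalizer           # cube_matmul(dout, dout) / N
assert matrix_g.shape == (4, 4)                   # (embedding_size, embedding_size)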
def _set_cache_enable(self):
"""EmbeddingLookup cache check for not ps env, which is only support 'ascend'."""
if self.target != 'DEVICE':
raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target.")
if not self.sparse:
raise ValueError("The configuration of 'vocab_cache_size' is valid only 'sparse' is true.")
if context.get_context("device_target") != 'Ascend':
raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'ascend'.")
logger.info("EmbeddingLookup cache enable takes effect.")
self.forward_unique = True
self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU')
self.unique.add_prim_attr('cache_enable', True)
self.embedding_table.cache_enable = self.cache_enable
self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size)
self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU')
def _process_vocab_cache(self, slice_mode):
"""PS embeddingLookup cache check and process."""
self.cache_enable = False
if self.vocab_cache_size > 0:
if self.target == 'CPU':
logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
"current target is CPU, so it will be ignored.")
return
enable_ps = _get_ps_context("enable_ps")
if not enable_ps:
logger.warning(
"The configuration of 'vocab_cache_size' is valid only in parameter server trainning "
"mode, current mode is not parameter server trainning mode, so it will be ignored.")
return
parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
if is_auto_parallel:
rank_size = get_group_size()
rank_id = get_rank()
full_batch = _get_full_batch()
if rank_size > 1 and not (full_batch and slice_mode == "table_row_slice"):
raise ValueError("The embeddingLookup cache of parameter server parallel only be used "
"in 'full_batch' and 'table_row_slice' parallel strategy.")
self.vocab_cache_size = self.vocab_cache_size * rank_size
_set_rank_id(rank_id)
self.cache_enable = True
if _is_role_worker():
self.vocab_size = self.vocab_cache_size
if context.get_context("enable_sparse") != self.sparse:
raise ValueError("The value of parameter 'sparse' must be same for all EmbeddingLookup "
"kernels and equal the value of 'enable_sparse' in context setting in "
"parameter server cache mode")
def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
"""PS embeddingLookup cache enable set."""
self.embedding_table.cache_enable = True
self.embedding_table.is_param_ps = True
_set_cache_enable(True)
if self.sparse:
self.forward_unique = True
if _is_role_worker():
_insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)
def construct(self, indices):
if self.target == "CPU":
out = self.embeddinglookup(self.embedding_table, indices, 0)
else:
if self.thor:
if self.forward_unique:
shp = self.shape(indices) + (self.embedding_size,)
indices_flatten = self.reshape_first(indices, (-1,))
unique_id, unique_idx = self.unique(indices_flatten)
one_hot_ids = self.one_hot(indices_flatten, self.vocab_size, self.on_value, self.off_value)
matrix_a = self.reduce_sum(one_hot_ids, 0)
matrix_a = self.cast(matrix_a, mstype.float16)
self.matrix_a = matrix_a
weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
out = self.getG(weight_unique)
weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
out = self.reshape(weight_flatten, shp)
else:
indices_flatten = self.reshape_first(indices, (-1,))
one_hot_ids = self.one_hot(indices_flatten, self.vocab_size, self.on_value, self.off_value)
matrix_a = self.reduce_sum(one_hot_ids, 0)
matrix_a = self.cast(matrix_a, mstype.float16)
self.matrix_a = matrix_a
out = self.gatherv2(self.embedding_table, indices, 0)
out = self.getG(out)
else:
if self.forward_unique:
shp = self.shape(indices) + (self.embedding_size,)
indices_flatten = self.reshape_first(indices, (-1,))
unique_id, unique_idx = self.unique(indices_flatten)
weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
out = self.reshape(weight_flatten, shp)
else:
out = self.gatherv2(self.embedding_table, indices, 0)
if self.max_norm is not None:
axis = _make_axis_range(F.rank(indices), F.rank(out))
clip_by_norm = ClipByNorm(axis)
out = clip_by_norm(out, self.max_norm)
return out
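A minimal usage sketch of the new layer (hypothetical sizes; per the docstring it is only supported on Ascend, and the statistics are stored in float16). The forward pass fills matrix_a with one-hot index counts, while matrix_g is written by the InsertGradientOf hook once gradients flow back:

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor

# Hypothetical vocab/embedding sizes for illustration only.
layer = nn.EmbeddingLookupThor(vocab_size=8, embedding_size=4,
                               target='DEVICE', sparse=False)
indices = Tensor(np.array([[1, 0], [3, 1]]), ms.int32)
out = layer(indices)              # shape (2, 2, 4); forward pass stores matrix_a
print(layer.matrix_a.asnumpy())   # one-hot counts: index 1 twice, 0 and 3 once
# layer.matrix_g is filled by save_gradient during the backward pass.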

View File

@ -20,16 +20,18 @@ from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
import mindspore.nn as nn
import mindspore.common.dtype as mstype
import mindspore.log as logger
from mindspore._checkparam import Validator
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor
from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor, EmbeddingLookupThor
from mindspore.nn.wrap import DistributedGradReducer
from mindspore.train.train_thor.convert_utils import ConvertNetUtils
from mindspore.parallel._auto_parallel_context import auto_parallel_context
# Enumerates types of Layer
Other = -1
Conv = 1
@ -120,6 +122,34 @@ def caculate_device_shape(matrix_dim, channel, is_a):
return ll
def is_conv_matmul_support_shape(matrix_a_shape, matrix_g_shape):
"""is conv layer matmul support shape"""
temp = (matrix_g_shape, matrix_a_shape)
support_shape = [((4, 4, 16, 16), (49, 49, 16, 16)),
((4, 4, 16, 16), (4, 4, 16, 16)),
((4, 4, 16, 16), (36, 36, 16, 16)),
((16, 16, 16, 16), (4, 4, 16, 16)),
((4, 4, 16, 16), (16, 16, 16, 16)),
((8, 8, 16, 16), (16, 16, 16, 16)),
((8, 8, 16, 16), (72, 72, 16, 16)),
((32, 32, 16, 16), (8, 8, 16, 16)),
((32, 32, 16, 16), (16, 16, 16, 16)),
((8, 8, 16, 16), (32, 32, 16, 16)),
((16, 16, 16, 16), (32, 32, 16, 16)),
((16, 16, 16, 16), (144, 144, 16, 16)),
((64, 64, 16, 16), (16, 16, 16, 16)),
((64, 64, 16, 16), (32, 32, 16, 16)),
((16, 16, 16, 16), (64, 64, 16, 16)),
((32, 32, 16, 16), (64, 64, 16, 16)),
((32, 32, 16, 16), (288, 288, 16, 16)),
((128, 128, 16, 16), (32, 32, 16, 16)),
((128, 128, 16, 16), (64, 64, 16, 16)),
((32, 32, 16, 16), (128, 128, 16, 16))]
if temp in support_shape:
return True
return False
def caculate_matmul_shape(matrix_a_dim, matrix_g_dim, split_dim):
"""get matmul shape"""
split_dima = split_dim
@ -151,21 +181,29 @@ def find_net_layertype_recur(net, layertype_map):
cells = net.name_cells()
for name in cells:
subcell = cells[name]
prefix = subcell.param_prefix
if subcell == net:
continue
elif isinstance(subcell, Conv2dThor):
layertype_map.append(Conv)
elif isinstance(subcell, DenseThor):
layertype_map.append(FC)
elif isinstance(subcell, EmbeddingThor):
elif isinstance(subcell, (EmbeddingThor, EmbeddingLookupThor)):
layertype_map.append(Embedding)
elif isinstance(subcell, nn.LayerNorm):
layertype_map.append(LayerNorm)
elif isinstance(subcell, nn.BatchNorm2d):
layertype_map.append(BatchNorm)
if subcell.gamma.requires_grad:
layertype_map.append(BatchNorm)
elif isinstance(subcell, (nn.Conv2d, nn.Dense, nn.Embedding, nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose,
nn.BatchNorm1d, nn.GroupNorm, nn.GlobalBatchNorm)):
layertype_map.append(Other)
if isinstance(subcell, (nn.Dense, nn.Conv2d)):
if "rpn_with_loss.rpn_convs_list." in prefix.lower():
continue
elif subcell.weight.requires_grad:
layertype_map.append(Other)
else:
layertype_map.append(Other)
else:
find_net_layertype_recur(subcell, layertype_map)
@ -188,7 +226,14 @@ def get_layer_counter(layer_type, layer_counter, params, idx):
if "beta" in params[idx].name.lower():
layer_counter = layer_counter + 1
else:
layer_counter = layer_counter + 1
if "bias" in params[idx].name.lower():
layer_counter = layer_counter + 1
elif "weight" in params[idx].name.lower():
if idx < len(params) - 1 and "bias" not in params[idx + 1].name.lower():
layer_counter = layer_counter + 1
else:
# for example: an embedding layer
layer_counter = layer_counter + 1
return layer_counter
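A worked illustration of the new counting rule (a plain-Python sketch with hypothetical parameter names; the real function also branches on the layer type): the counter advances after the last parameter of a layer, so a weight followed by its bias waits for the bias, a bias-less weight advances immediately, and a lone embedding table advances on its own.

# Hypothetical flattened parameter list: conv without bias, dense with bias,
# and an embedding layer whose only parameter is the table.
params = ["conv1.weight", "dense1.weight", "dense1.bias", "embed.embedding_table"]

def advances(names, idx):
    """Mirror of the new else-branch: True when the layer counter moves on."""
    name = names[idx].lower()
    if "bias" in name:
        return True
    if "weight" in name:
        # a weight closes its layer only when no bias follows it
        return idx < len(names) - 1 and "bias" not in names[idx + 1].lower()
    return True  # e.g. an embedding table

print([advances(params, i) for i in range(len(params))])
# [True, False, True, True]: the counter advances after conv1.weight,
# dense1.bias and embed.embedding_table.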
@ -315,6 +360,7 @@ class ThorGpu(Optimizer):
self.params = self.parameters
self.use_nesterov = Validator.check_bool(use_nesterov)
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
self.net = net
self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
@ -326,7 +372,7 @@ class ThorGpu(Optimizer):
self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
self.damping = damping
self._define_gpu_operator()
print("matrix_a_cov len is {}".format(len(self.matrix_a_cov)), flush=True)
logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
self.thor = True
self.matrix_a = ()
self.matrix_g = ()
@ -375,6 +421,7 @@ class ThorGpu(Optimizer):
self.square = P.Square()
self.expand = P.ExpandDims()
def _define_gpu_reducer(self, split_indices):
"""define gpu reducer"""
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
@ -602,17 +649,16 @@ class ThorAscend(Optimizer):
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum()
self.net = net
self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
self.matrix_g_cov = ParameterTuple(filter(lambda x: 'matrix_g' in x.name, net.get_parameters()))
self.a_normalizer = ParameterTuple(filter(lambda x: 'a_normalizer' in x.name, net.get_parameters()))
self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters()))
print("matrix_a_cov len is {}".format(len(self.matrix_a_cov)), flush=True)
logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
self._define_ascend_operator()
self.C0 = 16
self.matrix_a_dim = ()
self.pad_a_flag = ()
self.device_shape_pad_flag = ()
self.diag_block_dim = 128
self.matrix_a = ()
@ -622,6 +668,11 @@ class ThorAscend(Optimizer):
self.weight_conv_idx_map = ()
self.weight_fim_idx_map = ()
self.weight_layertype_idx_map = ()
self.a_split_pad_dim_map = ()
self.g_split_pad_dim_map = ()
self.conv_matmul_support_map = ()
self.batch_matmul_support_list = [1, 2, 4, 5, 6, 8, 9, 16, 18, 24, 32, 36]
self.abs_max_support_list = [1, 2, 4, 8, 16, 5, 9, 18, 36, 32]
self._process_matrix_init_and_weight_idx_map(self.net)
self.matrix_a = ParameterTuple(self.matrix_a)
self.matrix_g = ParameterTuple(self.matrix_g)
@ -646,6 +697,16 @@ class ThorAscend(Optimizer):
"""get thor frequency"""
return self.frequency
def _get_pad_dim(self, matrix_dim):
"""get diag split pad dim """
split_pad_dim = 0
if matrix_dim == 64:
return split_pad_dim
res = matrix_dim % self.diag_block_dim
if res != 0:
split_pad_dim = self.diag_block_dim - res
return split_pad_dim
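Worked arithmetic for the pad computation (a sketch; diag_block_dim is 128 as set in __init__, and 64 is special-cased to no padding): a 1001-dimensional matrix needs 128 - (1001 % 128) = 23 rows of padding to reach 1024, matching the 1001-class head handled elsewhere in this diff.

def get_pad_dim(matrix_dim, diag_block_dim=128):
    # plain-Python equivalent of _get_pad_dim
    if matrix_dim == 64:
        return 0
    res = matrix_dim % diag_block_dim
    return diag_block_dim - res if res else 0

assert get_pad_dim(1001) == 23   # padded to 1024 = 8 blocks of 128
assert get_pad_dim(256) == 0     # already a multiple of 128
assert get_pad_dim(64) == 0      # special case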
def _define_ascend_operator(self):
"""define ascend operator"""
self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
@ -663,8 +724,10 @@ class ThorAscend(Optimizer):
self.assign = P.Assign()
self.cast = P.Cast()
self.eye = P.Eye()
self.concat = P.Concat(0)
self.cholesky = P.CusCholeskyTrsm()
self.vector_matmul = P.CusBatchMatMul()
self.tbe_batch_matmul = P.BatchMatMul(transpose_a=True)
self.fused_abs_max2 = P.CusFusedAbsMax1()
self.matrix_combine = P.CusMatrixCombine()
self.slice = P.Slice()
@ -677,6 +740,7 @@ class ThorAscend(Optimizer):
self.one = Tensor(1, mstype.int32)
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
def _define_ascend_reducer(self, split_indices):
"""define ascend reducer"""
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
@ -691,13 +755,13 @@ class ThorAscend(Optimizer):
if self.conv_layer_count > 0:
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum2")
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum4")
self.grad_reducer_amax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=2)
self.grad_reducer_gmax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=4)
self.grad_reducer_amax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=3)
self.grad_reducer_gmax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=5)
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)
self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=9)
self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=17)
def _process_matrix_init_and_weight_idx_map(self, net):
"""for Ascend, process matrix init shape, and get weight idx map"""
@ -714,39 +778,66 @@ class ThorAscend(Optimizer):
matrix_g_dim = out_channels
matrix_a_device_shape, matrix_a_device_dim = caculate_device_shape(matrix_a_dim, in_channels, True)
matrix_g_device_shape, matrix_g_device_dim = caculate_device_shape(matrix_g_dim, in_channels, False)
matrix_a_inv = Parameter(
Tensor(np.reshape(np.identity(matrix_a_device_dim).astype(np.float16), matrix_a_device_shape)),
name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
matrix_g_inv = Parameter(
Tensor(np.reshape(np.identity(matrix_g_device_dim).astype(np.float16), matrix_g_device_shape)),
name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
ret = is_conv_matmul_support_shape(matrix_a_device_shape, matrix_g_device_shape)
if ret:
matrix_a_inv = Parameter(
Tensor(np.reshape(np.identity(matrix_a_device_dim).astype(np.float16), matrix_a_device_shape)),
name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
matrix_g_inv = Parameter(
Tensor(np.reshape(np.identity(matrix_g_device_dim).astype(np.float16), matrix_g_device_shape)),
name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
self.conv_matmul_support_map = self.conv_matmul_support_map + (1,)
else:
matrix_a_inv = Parameter(Tensor(np.eye(matrix_a_dim).astype(np.float16)),
name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
matrix_g_inv = Parameter(Tensor(np.eye(matrix_g_dim).astype(np.float16)),
name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
self.conv_matmul_support_map = self.conv_matmul_support_map + (0,)
self.matrix_a = self.matrix_a + (matrix_a_inv,)
self.matrix_g = self.matrix_g + (matrix_g_inv,)
self.matrix_a_dim = self.matrix_a_dim + (matrix_a_dim,)
pad_a_flag = False
if (matrix_a_dim // self.diag_block_dim) * self.diag_block_dim != matrix_a_dim \
and matrix_a_dim > self.diag_block_dim:
pad_a_flag = True
self.pad_a_flag = self.pad_a_flag + (pad_a_flag,)
device_shape_pad_flag = False
if matrix_a_dim != matrix_a_device_dim:
device_shape_pad_flag = True
self.device_shape_pad_flag = self.device_shape_pad_flag + (device_shape_pad_flag,)
elif layer_type == FC and "bias" not in self.params[idx].name.lower():
out_channels = weight_shape[0]
if out_channels == 1001:
fc_matrix_a = Parameter(Tensor(np.zeros([128, 128, 16, 16]).astype(np.float16)),
name='matrix_a_inv_' + str(self.thor_layer_count),
requires_grad=False)
fc_matrix_g = Parameter(Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16)),
name="matrix_g_inv_" + str(self.thor_layer_count),
requires_grad=False)
in_channels = weight_shape[1]
if self.conv_layer_count > 0:
if out_channels == 1001:
fc_matrix_a = Parameter(Tensor(np.zeros([128, 128, 16, 16]).astype(np.float16)),
name='matrix_a_inv_' + str(self.thor_layer_count),
requires_grad=False)
fc_matrix_g = Parameter(Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16)),
name="matrix_g_inv_" + str(self.thor_layer_count),
requires_grad=False)
else:
fc_matrix_a = Parameter(Tensor(np.eye(in_channels).astype(np.float16)),
name='matrix_a_inv_' + str(self.thor_layer_count),
requires_grad=False)
fc_matrix_g = Parameter(Tensor(np.eye(out_channels).astype(np.float16)),
name="matrix_g_inv_" + str(self.thor_layer_count),
requires_grad=False)
self.matrix_a = self.matrix_a + (fc_matrix_a,)
self.matrix_g = self.matrix_g + (fc_matrix_g,)
if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,)
if layer_type == Embedding:
a_pad_dim = 0
g_pad_dim = 0
self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)
self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
else:
out_channels = weight_shape[0]
g_pad_dim = self._get_pad_dim(out_channels)
self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
matrix_a_dim = weight_shape[1]
if layer_type == Conv:
matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3]
a_pad_dim = self._get_pad_dim(matrix_a_dim)
self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)
self.thor_layer_count = self.thor_layer_count + 1
if layer_type == Conv:
self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
@ -765,6 +856,39 @@ class ThorAscend(Optimizer):
if "output_bias" not in self.params[idx].name.lower():
layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)
def _process_batch_matmul(self, input_matrix):
"""process batch matmul"""
input_matrix_shape = self.shape(input_matrix)
if input_matrix_shape[0] in self.batch_matmul_support_list:
input_matrix = self.vector_matmul(input_matrix, input_matrix)
else:
input_matrix = self.tbe_batch_matmul(input_matrix, input_matrix)
return input_matrix
def _process_cholesky_pad(self, pad_dim, input_matrix, matrix_shape0):
"""process cholesky pad"""
if pad_dim > 0:
matrix_sup = self.eye(pad_dim, pad_dim, mstype.float32)
matrix_sup = P.Pad(((0, 0), (matrix_shape0, 0)))(matrix_sup)
input_matrix = P.Pad(((0, 0), (0, pad_dim)))(input_matrix)
input_matrix = self.concat((input_matrix, matrix_sup))
return input_matrix
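What the padding builds, as a NumPy sketch (an illustration of the shape logic, not the device kernel): the matrix is extended to the block-diagonal form [[M, 0], [0, I]], so the appended rows and columns invert to the identity and can simply be sliced away after the Cholesky-based inversion.

import numpy as np

def cholesky_pad(m, pad_dim):
    """Block-diagonal extension [[m, 0], [0, I]] mirroring _process_cholesky_pad."""
    if pad_dim <= 0:
        return m
    n = m.shape[0]
    sup = np.eye(pad_dim, dtype=m.dtype)
    sup = np.pad(sup, ((0, 0), (n, 0)))      # shift the identity block right
    m = np.pad(m, ((0, 0), (0, pad_dim)))    # zero-pad the input on the right
    return np.concatenate((m, sup), axis=0)

m = np.arange(9, dtype=np.float32).reshape(3, 3)
padded = cholesky_pad(m, 2)   # (5, 5): m top-left, identity bottom-right
assert np.allclose(padded[:3, :3], m)
assert np.allclose(padded[3:, 3:], np.eye(2))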
def _get_abs_max(self, matrix_inv, origin_dim):
"""get matrix abs max"""
cholesky_shape = self.shape(matrix_inv)
if cholesky_shape[0] in self.abs_max_support_list:
matrix_inv_max = P.CusFusedAbsMax1([origin_dim, origin_dim])(matrix_inv)
matrix_max = self.fused_abs_max2(matrix_inv_max)
matrix_inv = self.matrix_combine(matrix_inv)
else:
matrix_inv = self.matrix_combine(matrix_inv)
matrix_abs = P.Abs()(matrix_inv)
matrix_max = P.ReduceMax(keep_dims=False)(matrix_abs)
return matrix_max, matrix_inv
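When the block count is not in abs_max_support_list, the fallback path combines the matrix first and takes its absolute maximum with generic ops. In NumPy terms (a sketch of the fallback branch only):

import numpy as np

matrix_inv = np.random.randn(100, 100).astype(np.float32)  # combined inverse
matrix_max = np.abs(matrix_inv).max()                       # P.Abs + P.ReduceMax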
def _get_fc_ainv_ginv(self, index, damping_step, gradients, matrix_a_allreduce, matrix_g_allreduce,
matrix_a_max_allreduce, matrix_g_max_allreduce):
"""get fc layer ainv and ginv"""
@ -780,10 +904,14 @@ class ThorAscend(Optimizer):
g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32)
damping = self.sqrt(damping_step)
matrix_a = matrix_a + damping * a_eye
a_pad_dim = self.a_split_pad_dim_map[thor_layer_count]
matrix_a = self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0])
matrix_a_inv = self.cholesky(matrix_a)
matrix_a_inv = self.vector_matmul(matrix_a_inv, matrix_a_inv)
matrix_a_inv = self._process_batch_matmul(matrix_a_inv)
weight_shape = self.shape(self.params[index])
out_channels = weight_shape[0]
in_channels = weight_shape[1]
if out_channels == 2:
matrix_a_inv = self.matrix_combine(matrix_a_inv)
matrix_g_inv = g_eye
@ -791,27 +919,13 @@ class ThorAscend(Optimizer):
matrix_g = self.mul(matrix_g, self.loss_scale)
matrix_g = self.mul(matrix_g, self.batch_size_scale)
matrix_g = matrix_g + damping * g_eye
g_pad_dim = self.g_split_pad_dim_map[thor_layer_count]
matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0])
matrix_g_inv = self.cholesky(matrix_g)
matrix_g_inv = self.vector_matmul(matrix_g_inv, matrix_g_inv)
if out_channels == 1001:
matrix_a_inv_max = self.fused_abs_max2(matrix_a_inv)
a_max = self.fused_abs_max2(matrix_a_inv_max)
matrix_a_inv = self.matrix_combine(matrix_a_inv)
matrix_a_inv_shape = self.shape(matrix_a_inv)
matrix_a_inv = self.reshape(matrix_a_inv,
(matrix_a_inv_shape[0] / 16, 16,
matrix_a_inv_shape[0] / 16, 16))
matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
matrix_g_inv_max = P.CusFusedAbsMax1([1001, 1001])(matrix_g_inv)
g_max = self.fused_abs_max2(matrix_g_inv_max)
matrix_g_inv = self.matrix_combine(matrix_g_inv)
matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (1001, 1001))
matrix_g_inv = P.Pad(((0, 7), (0, 7)))(matrix_g_inv)
matrix_g_inv_shape = self.shape(matrix_g_inv)
matrix_g_inv = self.reshape(matrix_g_inv,
(matrix_g_inv_shape[0] / 16, 16,
matrix_g_inv_shape[0] / 16, 16))
matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))
matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
if self.conv_layer_count > 0:
a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, in_channels)
g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels)
a_max = F.depend(a_max, g)
g_max = F.depend(g_max, g)
matrix_a_max_allreduce = matrix_a_max_allreduce + (a_max,)
@ -819,6 +933,26 @@ class ThorAscend(Optimizer):
else:
matrix_a_inv = self.matrix_combine(matrix_a_inv)
matrix_g_inv = self.matrix_combine(matrix_g_inv)
if a_pad_dim > 0:
matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (in_channels, in_channels))
if g_pad_dim > 0:
matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels))
matrix_a_inv_shape = self.shape(matrix_a_inv)
matrix_g_combine_shape = self.shape(matrix_g_inv)
if matrix_a_inv_shape[0] == 2048 and matrix_g_combine_shape[0] == 1001:
matrix_a_inv = self.reshape(matrix_a_inv,
(matrix_a_inv_shape[0] / 16, 16,
matrix_a_inv_shape[0] / 16, 16))
matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
matrix_g_inv = P.Pad(((0, 7), (0, 7)))(matrix_g_inv)
matrix_g_inv_shape = self.shape(matrix_g_inv)
matrix_g_inv = self.reshape(matrix_g_inv,
(matrix_g_inv_shape[0] / 16, 16,
matrix_g_inv_shape[0] / 16, 16))
matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))
matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce
@ -830,8 +964,12 @@ class ThorAscend(Optimizer):
thor_layer_count = self.weight_fim_idx_map[i]
conv_layer_count = self.weight_conv_idx_map[i]
layer_type = self.weight_layertype_idx_map[i]
weight_shape = self.shape(self.params[i])
out_channels = weight_shape[0]
if layer_type == Conv:
g = gradients[i]
matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3]
matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
matrix_a = self.matrix_a_cov[thor_layer_count]
matrix_g = self.matrix_g_cov[thor_layer_count]
matrix_a = F.depend(matrix_a, g)
@ -848,43 +986,44 @@ class ThorAscend(Optimizer):
damping_g = self.mul(damping_step, self.batch_size / g_normalizer)
damping_a = self.sqrt(damping_a)
matrix_a = matrix_a + damping_a * a_eye
a_pad_dim = self.a_split_pad_dim_map[thor_layer_count]
matrix_a = self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0])
matrix_a_inv = self.cholesky(matrix_a)
matrix_a_inv = self.vector_matmul(matrix_a_inv, matrix_a_inv)
a_max = P.CusFusedAbsMax1([self.matrix_a_dim[conv_layer_count],
self.matrix_a_dim[conv_layer_count]])(matrix_a_inv)
a_max = self.fused_abs_max2(a_max)
matrix_a_inv = self.matrix_combine(matrix_a_inv)
if self.pad_a_flag[conv_layer_count]:
matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (self.matrix_a_dim[conv_layer_count],
self.matrix_a_dim[conv_layer_count]))
if self.device_shape_pad_flag[conv_layer_count]:
weight = self.params[i]
weight_shape = self.shape(weight)
kernel_hw = weight_shape[2] * weight_shape[3]
in_channels = weight_shape[1]
matrix_a_inv = self.reshape(matrix_a_inv, (kernel_hw, in_channels, kernel_hw, in_channels))
matrix_a_inv = P.Pad(((0, 0), (0, self.C0 - in_channels), (0, 0),
(0, self.C0 - in_channels)))(matrix_a_inv)
matrix_a_inv_shape = self.shape(self.matrix_a[thor_layer_count])
matrix_a_device_temp_shape = (matrix_a_inv_shape[0], matrix_a_inv_shape[2],
matrix_a_inv_shape[1], matrix_a_inv_shape[3])
matrix_a_inv = self.reshape(matrix_a_inv, matrix_a_device_temp_shape)
matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
matrix_a_inv = self._process_batch_matmul(matrix_a_inv)
a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, matrix_a_dim)
damping_g = self.sqrt(damping_g)
matrix_g = self.mul(matrix_g, self.loss_scale)
matrix_g = self.mul(matrix_g, self.batch_size_scale)
matrix_g = matrix_g + damping_g * g_eye
g_pad_dim = self.g_split_pad_dim_map[thor_layer_count]
matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0])
matrix_g_inv = self.cholesky(matrix_g)
matrix_g_inv = self.vector_matmul(matrix_g_inv, matrix_g_inv)
g_max = self.fused_abs_max2(matrix_g_inv)
g_max = self.fused_abs_max2(g_max)
matrix_g_inv = self.matrix_combine(matrix_g_inv)
matrix_g_inv_shape = self.shape(self.matrix_g[thor_layer_count])
matrix_g_device_temp_shape = (matrix_g_inv_shape[0], matrix_g_inv_shape[2],
matrix_g_inv_shape[1], matrix_g_inv_shape[3])
matrix_g_inv = self.reshape(matrix_g_inv, matrix_g_device_temp_shape)
matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))
matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels)
if a_pad_dim > 0:
matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (matrix_a_dim, matrix_a_dim))
if g_pad_dim > 0:
matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels))
if matmul_support_flag == 1:
if self.device_shape_pad_flag[conv_layer_count]:
kernel_hw = weight_shape[2] * weight_shape[3]
in_channels = weight_shape[1]
matrix_a_inv = self.reshape(matrix_a_inv, (kernel_hw, in_channels, kernel_hw, in_channels))
matrix_a_inv = P.Pad(((0, 0), (0, self.C0 - in_channels), (0, 0),
(0, self.C0 - in_channels)))(matrix_a_inv)
matrix_a_inv_shape = self.shape(self.matrix_a[thor_layer_count])
matrix_a_device_temp_shape = (matrix_a_inv_shape[0], matrix_a_inv_shape[2],
matrix_a_inv_shape[1], matrix_a_inv_shape[3])
matrix_a_inv = self.reshape(matrix_a_inv, matrix_a_device_temp_shape)
matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
matrix_g_inv_shape = self.shape(self.matrix_g[thor_layer_count])
matrix_g_device_temp_shape = (matrix_g_inv_shape[0], matrix_g_inv_shape[2],
matrix_g_inv_shape[1], matrix_g_inv_shape[3])
matrix_g_inv = self.reshape(matrix_g_inv, matrix_g_device_temp_shape)
matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))
a_max = F.depend(a_max, g)
g_max = F.depend(g_max, g)
@ -913,7 +1052,7 @@ class ThorAscend(Optimizer):
matrix_g = self.mul(matrix_g, self.batch_size_scale)
matrix_g = matrix_g + damping * g_eye
matrix_g_inv = self.cholesky(matrix_g)
matrix_g_inv = self.vector_matmul(matrix_g_inv, matrix_g_inv)
matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
matrix_g_inv = self.matrix_combine(matrix_g_inv)
matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
@ -951,16 +1090,39 @@ class ThorAscend(Optimizer):
for i in range(params_len):
g = gradients[i]
thor_layer_count = self.weight_fim_idx_map[i]
conv_layer_count = self.weight_conv_idx_map[i]
layer_type = self.weight_layertype_idx_map[i]
matrix_a = self.matrix_a[thor_layer_count]
matrix_g = self.matrix_g[thor_layer_count]
matrix_max = self.matrix_max_inv[thor_layer_count]
grad_shape = self.shape(g)
if layer_type == FC:
g = self.cube_matmul_left_fc(matrix_g, g)
g = self.cube_matmul_right_fc(g, matrix_a, matrix_max)
if grad_shape[0] == 1001:
g = self.cube_matmul_left_fc(matrix_g, g)
g = self.cube_matmul_right_fc(g, matrix_a, matrix_max)
else:
temp_a = self.cast(matrix_a, mstype.float16)
temp_g = self.cast(matrix_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, matrix_max)
elif layer_type == Conv:
g = self.cube_matmul_left(matrix_g, g)
g = self.cube_matmul_right_mul(g, matrix_a, matrix_max)
matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
if matmul_support_flag == 1:
g = self.cube_matmul_left(matrix_g, g)
g = self.cube_matmul_right_mul(g, matrix_a, matrix_max)
else:
g = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
temp_a = self.cast(matrix_a, mstype.float16)
temp_g = self.cast(matrix_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, matrix_max)
g = self.reshape(g, grad_shape)
new_grads = new_grads + (g,)
else:
for i in range(params_len):
@ -994,15 +1156,37 @@ class ThorAscend(Optimizer):
"""get second gradient by matmul"""
conv_layer_count = self.weight_conv_idx_map[index]
layer_type = self.weight_layertype_idx_map[index]
grad_shape = self.shape(g)
if layer_type == FC:
g = self.cube_matmul_left_fc(temp_g, g)
g = self.cube_matmul_right_fc(g, temp_a, temp_max)
if grad_shape[0] == 1001:
g = self.cube_matmul_left_fc(temp_g, g)
g = self.cube_matmul_right_fc(g, temp_a, temp_max)
else:
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, temp_max)
elif layer_type == Conv:
a_normalizer = self.a_normalizer[conv_layer_count]
a_normalizer = F.depend(a_normalizer, g)
temp_max = self.mul(temp_max, self.batch_size / a_normalizer)
g = self.cube_matmul_left(temp_g, g)
g = self.cube_matmul_right_mul(g, temp_a, temp_max)
matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
if matmul_support_flag == 1:
g = self.cube_matmul_left(temp_g, g)
g = self.cube_matmul_right_mul(g, temp_a, temp_max)
else:
g = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
g = self.mul(g, temp_max)
g = self.reshape(g, grad_shape)
return g, temp_max
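The new fallback branch applies the standard THOR preconditioning g <- (G^-1 g A^-1) * matrix_max with plain float16 matmuls instead of the fused cube ops reserved for the whitelisted shapes. A NumPy sketch of the math, with illustrative shapes:

import numpy as np

def precondition(g, a_inv, g_inv, matrix_max):
    """g <- (G^-1 @ g @ A^-1) * matrix_max, the generic (non-cube) path."""
    grad_shape = g.shape
    g2d = g.reshape(grad_shape[0], -1)   # conv grads are flattened first
    g2d = g_inv @ g2d @ a_inv            # float16 matmuls in the real code
    return (g2d * matrix_max).reshape(grad_shape)

out_ch, in_ch, kh, kw = 4, 3, 2, 2
g = np.random.randn(out_ch, in_ch, kh, kw).astype(np.float32)
a_inv = np.eye(in_ch * kh * kw, dtype=np.float32)   # A^-1: dim = C_in * kh * kw
g_inv = np.eye(out_ch, dtype=np.float32)            # G^-1: dim = C_out
assert np.allclose(precondition(g, a_inv, g_inv, 1.0), g)  # identity case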
def _get_second_grad_by_layertype(self, index, matrix_a_allreduce, matrix_g_allreduce, g, damping_step):

View File

@ -63,9 +63,19 @@ def cus_matrix_combine(input_x, output, kernel_name="cus_matrix_combine"):
repeat_real = tiling_dim * matrix_dim // 64
if repeat_real <= 255:
tik_instance.vector_dup(64, input_x_ub, zero, repeat_real, 1, 8)
else:
elif repeat_real <= 510:
tik_instance.vector_dup(64, input_x_ub, zero, 255, 1, 8)
tik_instance.vector_dup(64, input_x_ub[255 * 64], zero, repeat_real - 255, 1, 8)
elif repeat_real <= 765:
tik_instance.vector_dup(64, input_x_ub, zero, 255, 1, 8)
tik_instance.vector_dup(64, input_x_ub[255 * 64], zero, 255, 1, 8)
tik_instance.vector_dup(64, input_x_ub[510 * 64], zero, repeat_real - 510, 1, 8)
else:
tik_instance.vector_dup(64, input_x_ub, zero, 255, 1, 8)
tik_instance.vector_dup(64, input_x_ub[255 * 64], zero, 255, 1, 8)
tik_instance.vector_dup(64, input_x_ub[510 * 64], zero, 255, 1, 8)
tik_instance.vector_dup(64, input_x_ub[765 * 64], zero, repeat_real - 765, 1, 8)
with tik_instance.for_range(0, tiling_dim) as j:
tik_instance.data_move(input_x_ub[j, 128 * i], input_x[i, block_index * tiling_dim + j, 0],
0, 1, 16, 0, 0)
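The extra branches above exist because TIK's vector_dup accepts at most 255 repeats per call, so larger fills are issued in chunks of 255 repeats (64 float32 elements each) at increasing offsets. A sketch of the same chunking as a loop (illustrative; the unrolled form is what the TIK kernel emits):

def vector_dup_chunks(repeat_real, max_repeat=255, lanes=64):
    """Yield (element_offset, repeats) pairs covering repeat_real repeats."""
    done = 0
    while done < repeat_real:
        step = min(max_repeat, repeat_real - done)
        yield done * lanes, step
        done += step

print(list(vector_dup_chunks(600)))
# [(0, 255), (16320, 255), (32640, 90)] -- the 255 / 255 / remainder pattern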

View File

@ -25,10 +25,10 @@ new_im2col_op_info = TBERegOp("NewIm2Col") \
.attr("ksizes", "required", "listInt", "all") \
.attr("strides", "optional", "listInt", "all", "1") \
.attr("dilations", "optional", "listInt", "all", "1") \
.attr("padding_mode", "optional", "listInt", "all", "SAME") \
.attr("padding_mode", "optional", "str", "all", "SAME") \
.attr("pad_list", "optional", "listInt", "all", "0") \
.input(0, "x", False, "required", "all") \
.output(0, "output", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_Default) \
.dtype_format(DataType.I8_5HD, DataType.I8_Default) \
.get_op_info()

View File

@ -577,11 +577,11 @@ class NewIm2Col(PrimitiveWithInfer):
stride_w = self.strides
dilation_h = self.dilations
dilation_w = self.dilations
if self.pad_mode == "VALID":
if self.padding_mode == "VALID":
h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
elif self.pad_mode == "SAME":
elif self.padding_mode == "SAME":
h_out = math.ceil(x_shape[2] / stride_h)
w_out = math.ceil(x_shape[3] / stride_w)
pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2])
@ -601,7 +601,7 @@ class NewIm2Col(PrimitiveWithInfer):
def infer_dtype(self, x_dtype):
"infer dtype"
valid_dtypes = [mstype.float16, mstype.float32]
valid_dtypes = [mstype.float16, mstype.int8]
validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
return x_dtype

View File

@ -27,7 +27,9 @@ class ConvertNetUtils:
def __init__(self):
self._convert_method_map = {nn.Dense: ConvertNetUtils._convert_dense,
nn.Embedding: ConvertNetUtils._convert_embedding,
nn.Conv2d: ConvertNetUtils._convert_conv2d}
nn.Conv2d: ConvertNetUtils._convert_conv2d,
nn.EmbeddingLookup: ConvertNetUtils._convert_embeddinglookup}
@staticmethod
def _convert_dense(subcell):
@ -63,6 +65,7 @@ class ConvertNetUtils:
new_subcell.bias = subcell.bias
return new_subcell
@staticmethod
def _convert_embedding(subcell):
"""
@ -74,6 +77,20 @@ class ConvertNetUtils:
new_subcell.embedding_table = subcell.embedding_table
return new_subcell
@staticmethod
def _convert_embeddinglookup(subcell):
"""
convert the embedding lookup cell to its second-order (THOR) counterpart
"""
new_subcell = nn.EmbeddingLookupThor(vocab_size=subcell.vocab_size,
embedding_size=subcell.embedding_size,
target=subcell.target, sparse=subcell.sparse,
vocab_cache_size=subcell.vocab_cache_size)
new_subcell.embedding_table = subcell.embedding_table
return new_subcell
@staticmethod
def _convert_conv2d(subcell):
"""
@ -87,11 +104,22 @@ class ConvertNetUtils:
pad_mode = subcell.pad_mode
has_bias = subcell.has_bias
weight = subcell.weight
new_subcell = nn.Conv2dThor(in_channel, out_channel,
kernel_size=kernel_size, stride=stride, padding=padding, pad_mode=pad_mode,
has_bias=has_bias, weight_init=weight)
return new_subcell
def _need_change(self, subcell, prefix):
"""for thor layers, need to change"""
if isinstance(subcell, (nn.Dense, nn.Conv2d)) and subcell.weight.requires_grad:
if "rpn_with_loss.rpn_convs_list." in prefix.lower() or "wide" in prefix.lower():
return False
return True
if isinstance(subcell, (nn.Embedding, nn.EmbeddingLookup)) and subcell.embedding_table.requires_grad:
return True
return False
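A quick check of the conversion predicate with hypothetical prefixes: trainable Dense/Conv2d cells are converted unless their prefix marks an excluded subnetwork (RPN conv lists, 'wide' branches), while trainable embedding cells are always converted.

# Plain-Python mirror of the prefix filter inside _need_change.
def prefix_excluded(prefix):
    p = prefix.lower()
    return "rpn_with_loss.rpn_convs_list." in p or "wide" in p

assert prefix_excluded("rpn_with_loss.rpn_convs_list.0")  # RPN convs skipped
assert prefix_excluded("wide_and_deep.wide.dense")        # 'wide' branch skipped
assert not prefix_excluded("backbone.layer1.conv1")       # ordinary conv converted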
def _convert_to_thor_net(self, net):
"""
Convert net to thor net
@ -102,23 +130,25 @@ class ConvertNetUtils:
subcell = cells[name]
if subcell == net:
continue
elif isinstance(subcell, (nn.DenseThor, nn.Conv2dThor, nn.EmbeddingThor)):
elif isinstance(subcell, (nn.DenseThor, nn.Conv2dThor, nn.EmbeddingThor, nn.EmbeddingLookupThor)):
continue
elif isinstance(subcell, (nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose, nn.BatchNorm1d, nn.GroupNorm,
nn.GlobalBatchNorm, nn.LayerNorm, nn.BatchNorm2d, nn.MaxPool2d)):
continue
elif isinstance(subcell, (nn.Embedding, nn.Dense, nn.Conv2d)):
elif isinstance(subcell, (nn.Embedding, nn.Dense, nn.Conv2d, nn.EmbeddingLookup)):
prefix = subcell.param_prefix
new_subcell = self._convert_method_map[type(subcell)](subcell)
new_subcell.update_parameters_name(prefix + '.')
net.insert_child_to_cell(name, new_subcell)
change = True
if self._need_change(subcell, prefix):
new_subcell = self._convert_method_map[type(subcell)](subcell)
new_subcell.update_parameters_name(prefix + '.')
net.insert_child_to_cell(name, new_subcell)
change = True
else:
self._convert_to_thor_net(subcell)
if isinstance(net, nn.SequentialCell) and change:
net.cell_list = list(net.cells())
def convert_to_thor_net(self, net):
"""
This interface is used to convert a network to thor layer network, in order to calculate and store the
@ -174,7 +204,7 @@ class ConvertModelUtils:
Otherwise, scale the loss by LossScaleManager, in which case optimizer cannot be None. It is a key
argument, e.g. use `loss_scale_manager=None` to set the value.
keep_batchnorm_fp32 (bool): Keep Batchnorm running in `float32`. If True, the level setting before
will be overwritten. Default: True.
will be overwritten. Default: False.
Returns:
model (Object): High-Level API for Training.
@ -195,9 +225,9 @@ class ConvertModelUtils:
>>> model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_manager, metrics={"acc"},
... amp_level="O2", keep_batchnorm_fp32=False)
>>> model = ConvertModelUtils.convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,
... metrics={'acc'}, amp_level="O2",
... loss_scale_manager=loss_manager,
... keep_batchnorm_fp32=False)
... metrics={'acc'}, amp_level="O2",
... loss_scale_manager=loss_manager,
... keep_batchnorm_fp32=False)
"""
optim_name = type(optimizer).__name__

View File

@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model."""
"""High-Level API for Second Order Training or Testing.
Second-order optimizer THOR reduces the computation workload and improves the computation speed by reducing the
frequency of updating the second-order matrix. In order to optimize the overall performance, the ModelThor class
is redefined to inherit the Model class provided by MindSpore. The parameter of THOR for controlling the frequency
of updating the second-order matrix can be obtained by the ModelThor class. """
import math
from mindspore.train.callback import RunContext
from mindspore import context

View File

@ -33,8 +33,6 @@ label_smooth_factor: 0.1
lr_init: 0.05803
lr_decay: 4.04839
lr_end_epoch: 53
lars_epsilon: 0.0
lars_coefficient: 0.001
damping_init: 0.02714
damping_decay: 0.50036
frequency: 834
@ -54,8 +52,8 @@ enable_cache: False
cache_session_id: ""
mode_name: "GRAPH"
acc_mode: "O0"
conv_init: "XavierUniform"
dense_init: "TruncatedNormal"
conv_init: "HeUniform"
dense_init: "HeUniform"
all_reduce_fusion_config:
- 85
- 160

View File

@ -33,8 +33,6 @@ label_smooth_factor: 0.1
lr_init: 0.05672
lr_decay: 4.9687
lr_end_epoch: 50
lars_epsilon: 0.0
lars_coefficient: 0.001
damping_init: 0.02345
damping_decay: 0.5467
frequency: 834
@ -54,8 +52,8 @@ enable_cache: False
cache_session_id: ""
mode_name: "GRAPH"
acc_mode: "O0"
conv_init: "XavierUniform"
dense_init: "TruncatedNormal"
conv_init: "HeUniform"
dense_init: "HeUniform"
all_reduce_fusion_config:
- 85
- 160

View File

@ -21,6 +21,7 @@ import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from src.model_utils.config import config
def conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
@ -208,6 +209,8 @@ class ResidualBlock(nn.Cell):
self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
self.bn3 = _bn(out_channel)
if config.optimizer == "Thor":
self.bn3 = _bn_last(out_channel)
if self.se_block:
self.se_global_pool = P.ReduceMean(keep_dims=False)
self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)

View File

@ -28,6 +28,7 @@ from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.nn.optim import thor
import mindspore.nn as nn
import mindspore.dataset as ds
@ -41,7 +42,7 @@ from tests.st.networks.models.resnet50.src_thor.config import config as thor_con
from tests.st.networks.models.resnet50.src_thor.dataset import create_dataset as create_dataset_thor
from tests.st.networks.models.resnet50.src_thor.model_thor import Model as THOR_Model
from tests.st.networks.models.resnet50.src_thor.resnet import resnet50 as resnet50_thor
from tests.st.networks.models.resnet50.src_thor.thor import THOR
MINDSPORE_HCCL_CONFIG_PATH = "/home/workspace/mindspore_config/hccl/rank_tabel_4p/rank_table_4p_1.json"
MINDSPORE_HCCL_CONFIG_PATH_2 = "/home/workspace/mindspore_config/hccl/rank_tabel_4p/rank_table_4p_2.json"
@ -268,8 +269,8 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
damping = get_thor_damping(0, 0.02714, 0.50036, 70, 5004)
# optimizer
split_indices = [26, 53]
opt = THOR(net, Tensor(lr), Tensor(damping), thor_config.momentum, thor_config.weight_decay, thor_config.loss_scale,
thor_config.batch_size, split_indices=split_indices)
opt = thor(net, Tensor(lr), Tensor(damping), thor_config.momentum, thor_config.weight_decay, thor_config.loss_scale,
thor_config.batch_size, split_indices=split_indices, frequency=thor_config.frequency)
# evaluation network
dist_eval_network = ClassifyCorrectCell(net)