forked from mindspore-Ecosystem/mindspore
!22358 thor generalization code submit
Merge pull request !22358 from wangshuangling/master
This commit is contained in:
commit
76a37daa43
|
@ -33,7 +33,7 @@ from .quant import *
|
|||
from .math import *
|
||||
from .combined import *
|
||||
from .timedistributed import *
|
||||
from .thor_layer import DenseThor, Conv2dThor, EmbeddingThor
|
||||
from .thor_layer import DenseThor, Conv2dThor, EmbeddingThor, EmbeddingLookupThor
|
||||
|
||||
__all__ = []
|
||||
__all__.extend(activation.__all__)
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
"""layers for second order optimization"""
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.log as logger
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.initializer import initializer, Initializer
|
||||
from mindspore.ops import operations as P
|
||||
|
@ -24,9 +25,14 @@ from mindspore._checkparam import Validator, Rel, twice
|
|||
from mindspore import context
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context
|
||||
from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops import functional as F
|
||||
from .basic import ClipByNorm
|
||||
|
||||
|
||||
__all__ = ['DenseThor', 'Conv2dThor', 'EmbeddingThor']
|
||||
__all__ = ['DenseThor', 'Conv2dThor', 'EmbeddingThor', 'EmbeddingLookupThor']
|
||||
|
||||
|
||||
class DenseThor(Cell):
|
||||
|
@ -75,6 +81,7 @@ class DenseThor(Cell):
|
|||
[[ 6. 6. 6. 6.]
|
||||
[ 12. 12. 12. 12. ]]
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
|
@ -93,7 +100,6 @@ class DenseThor(Cell):
|
|||
weight_init.shape[1] != in_channels:
|
||||
raise ValueError("Weight init shape error.")
|
||||
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
|
||||
|
||||
self.bias = None
|
||||
if self.has_bias:
|
||||
if isinstance(bias_init, Tensor):
|
||||
|
@ -106,38 +112,25 @@ class DenseThor(Cell):
|
|||
self.activation = get_activation(activation)
|
||||
self.activation_flag = self.activation is not None
|
||||
|
||||
self.matrix_a = Parameter(Tensor(np.zeros([in_channels, in_channels]).astype(np.float32)),
|
||||
self.matrix_a = Parameter(Tensor(np.eye(in_channels).astype(np.float32)),
|
||||
name='matrix_a', requires_grad=False)
|
||||
self.matrix_g = Parameter(Tensor(np.eye(out_channels).astype(np.float32)),
|
||||
name="matrix_g", requires_grad=False)
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
self.mul = P.Mul()
|
||||
self.is_Ascend = True
|
||||
self.split_dim = 128
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
self._process_ascend_dense_thor(out_channels)
|
||||
self._process_ascend_dense_thor(out_channels, in_channels)
|
||||
else:
|
||||
self.is_Ascend = False
|
||||
self.matrix_g = Parameter(Tensor(np.eye(out_channels).astype(np.float32)),
|
||||
name="matrix_g", requires_grad=False)
|
||||
self.cube_matmul = P.MatMul(transpose_a=True)
|
||||
self.getG = P.InsertGradientOf(self.save_gradient)
|
||||
|
||||
def _process_ascend_dense_thor(self, out_channels):
|
||||
def _process_ascend_dense_thor(self, out_channels, in_channels):
|
||||
"""process ascend dense thor"""
|
||||
if out_channels == 1001:
|
||||
self.matrix_g = Parameter(Tensor(np.zeros([1024, 1024]).astype(np.float32)),
|
||||
name='matrix_g', requires_grad=False)
|
||||
self.pad = P.Pad(((0, 23), (0, 23)))
|
||||
self.pad1 = P.Pad(((0, 7), (0, 7)))
|
||||
self.slice = P.Slice()
|
||||
self.add = P.TensorAdd()
|
||||
else:
|
||||
self.matrix_g = Parameter(Tensor(np.eye(out_channels).astype(np.float32)),
|
||||
name="matrix_g", requires_grad=False)
|
||||
self.abs = P.Abs()
|
||||
self.reduce_max = P.ReduceMax(keep_dims=False)
|
||||
self.neg = P.Neg()
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.matmul = P.MatMul(transpose_b=True)
|
||||
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
|
||||
self.cast = P.Cast()
|
||||
|
@ -155,8 +148,6 @@ class DenseThor(Cell):
|
|||
normalizer = self.cast(shape[0], mstype.float32)
|
||||
matrix_g = self.cube_matmul(dout, dout)
|
||||
matrix_g = self.mul(matrix_g, 1.0 / normalizer)
|
||||
if self.out_channels == 1001:
|
||||
matrix_g = P.Pad(((0, 23), (0, 23)))(matrix_g)
|
||||
self.matrix_g = matrix_g
|
||||
else:
|
||||
dout_shape = self.shape(dout)
|
||||
|
@ -380,7 +371,6 @@ class Conv2dThor(_ConvThor):
|
|||
stride=self.stride, dilation=self.dilation, group=self.group)
|
||||
self._init_depthwise_conv2d(weight_init)
|
||||
self.bias_add = P.BiasAdd()
|
||||
|
||||
self.thor = True
|
||||
self.hw = kernel_size[0] * kernel_size[1]
|
||||
self.matrix_a_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
|
||||
|
@ -389,8 +379,8 @@ class Conv2dThor(_ConvThor):
|
|||
self.reshape = P.Reshape()
|
||||
self.mul = P.Mul()
|
||||
self.cast = P.Cast()
|
||||
self.a_normalizer = Parameter(initializer(0, [1], mstype.float32), name="a_normalizer", requires_grad=False)
|
||||
self.g_normalizer = Parameter(initializer(0, [1], mstype.float32), name="g_normalizer", requires_grad=False)
|
||||
self.a_normalizer = Parameter(initializer(1, [1], mstype.float32), name="a_normalizer", requires_grad=False)
|
||||
self.g_normalizer = Parameter(initializer(1, [1], mstype.float32), name="g_normalizer", requires_grad=False)
|
||||
self.is_Ascend = True
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
self._process_ascend_conv2d_thor(kernel_size, stride)
|
||||
|
@ -409,32 +399,18 @@ class Conv2dThor(_ConvThor):
|
|||
"""process ascend conv2d thor"""
|
||||
ksizes = (1, kernel_size[0], kernel_size[1], 1)
|
||||
strides = (1, stride[0], stride[1], 1)
|
||||
ksizes_tbe = (kernel_size[0], kernel_size[1])
|
||||
self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides)
|
||||
self.transpose = P.Transpose()
|
||||
self.reshape = P.Reshape()
|
||||
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
|
||||
self.transpose02314 = P.CusTranspose02314()
|
||||
dampinga_dim = self.matrix_a_dim
|
||||
self.diag_block_dim = 128
|
||||
if (self.matrix_a_dim % self.diag_block_dim) != 0 and self.matrix_a_dim > self.diag_block_dim:
|
||||
dampinga_dim = (self.matrix_a_dim // self.diag_block_dim + 1) * self.diag_block_dim
|
||||
dampingg_dim = self.matrix_g_dim
|
||||
if (self.matrix_g_dim % self.diag_block_dim) != 0 and self.matrix_g_dim > self.diag_block_dim:
|
||||
dampingg_dim = (self.matrix_g_dim // self.diag_block_dim + 1) * self.diag_block_dim
|
||||
self.matrix_a_cov = Parameter(Tensor(np.zeros([dampinga_dim, dampinga_dim]).astype(np.float32)),
|
||||
self.matrix_a_cov = Parameter(Tensor(np.eye(self.matrix_a_dim).astype(np.float32)),
|
||||
name='matrix_a', requires_grad=False)
|
||||
self.matrix_g_cov = Parameter(Tensor(np.zeros([dampingg_dim, dampingg_dim]).astype(np.float32)),
|
||||
self.matrix_g_cov = Parameter(Tensor(np.eye(self.matrix_g_dim).astype(np.float32)),
|
||||
name='matrix_g', requires_grad=False)
|
||||
|
||||
self.channels_slice_flag = False
|
||||
self.C0 = 16
|
||||
if self.in_channels % self.C0 != 0:
|
||||
self.channels_slice_flag = True
|
||||
self.pada_flag = False
|
||||
if (self.matrix_a_dim // self.diag_block_dim) * self.diag_block_dim != self.matrix_a_dim \
|
||||
and self.matrix_a_dim > self.diag_block_dim:
|
||||
self.pada_flag = True
|
||||
pad_dim = self.diag_block_dim - self.matrix_a_dim % self.diag_block_dim
|
||||
self.pada = P.Pad(((0, pad_dim), (0, pad_dim)))
|
||||
self.slice = P.Slice()
|
||||
self.im2col = P.NewIm2Col(ksizes=ksizes_tbe, strides=stride[0], padding_mode="SAME")
|
||||
|
||||
def _init_depthwise_conv2d(self, weight_init):
|
||||
"""Initialize depthwise conv2d op"""
|
||||
|
@ -460,7 +436,9 @@ class Conv2dThor(_ConvThor):
|
|||
"""save_gradient"""
|
||||
out = dout
|
||||
if self.is_Ascend:
|
||||
dout = self.transpose02314(dout)
|
||||
dout_shape = self.shape(dout)
|
||||
dout = self.transpose(dout, (0, 2, 3, 1))
|
||||
dout = self.reshape(dout, (-1, dout_shape[1]))
|
||||
dout_shape = self.shape(dout)
|
||||
normalizer = dout_shape[0]
|
||||
matrix_g = self.cube_matmul(dout, dout)
|
||||
|
@ -483,23 +461,24 @@ class Conv2dThor(_ConvThor):
|
|||
|
||||
def construct(self, x):
|
||||
if self.thor:
|
||||
matrix_a = self.img2col(x)
|
||||
matrix_a_shape = self.shape(matrix_a)
|
||||
if self.is_Ascend:
|
||||
matrix_a = self.im2col(x)
|
||||
matrix_a_shape = self.shape(matrix_a)
|
||||
y = matrix_a_shape[3]
|
||||
matrix_a = self.reshape(matrix_a, (-1, y))
|
||||
matrix_a_shape = self.shape(matrix_a)
|
||||
normalizer = matrix_a_shape[0]
|
||||
matrix_a = self.cube_matmul(matrix_a, matrix_a)
|
||||
if self.channels_slice_flag:
|
||||
matrix_a = self.reshape(matrix_a, (self.hw, self.C0, self.hw, self.C0))
|
||||
matrix_a = self.slice(matrix_a, (0, 0, 0, 0),
|
||||
(self.hw, self.in_channels, self.hw, self.in_channels))
|
||||
matrix_a = self.reshape(matrix_a, (self.matrix_a_dim, self.matrix_a_dim))
|
||||
normalizer = self.cast(normalizer, mstype.float32)
|
||||
matrix_a = self.mul(matrix_a, 1.0 / normalizer)
|
||||
if self.pada_flag:
|
||||
matrix_a = self.pada(matrix_a)
|
||||
self.a_normalizer = normalizer
|
||||
self.matrix_a_cov = matrix_a
|
||||
weight = self.cast(self.weight, mstype.float16)
|
||||
output = self.conv2d(x, weight)
|
||||
output = self.getG(output)
|
||||
else:
|
||||
matrix_a = self.img2col(x)
|
||||
matrix_a_shape = self.shape(matrix_a)
|
||||
matrix_a = self.reshape(matrix_a, (matrix_a_shape[0] * matrix_a_shape[1] * matrix_a_shape[2],
|
||||
matrix_a_shape[3], -1))
|
||||
matrix_a = self.reduce_mean(matrix_a, 1)
|
||||
|
@ -512,9 +491,17 @@ class Conv2dThor(_ConvThor):
|
|||
self.matrix_a_cov = matrix_a
|
||||
output = self.conv2d(x, self.weight)
|
||||
output = self.getG(output)
|
||||
else:
|
||||
if self.is_Ascend:
|
||||
weight = self.cast(self.weight, mstype.float16)
|
||||
output = self.conv2d(x, weight)
|
||||
else:
|
||||
output = self.conv2d(x, self.weight)
|
||||
if self.has_bias:
|
||||
if self.is_Ascend:
|
||||
bias = self.cast(self.bias, mstype.float16)
|
||||
output = self.bias_add(output, bias)
|
||||
else:
|
||||
output = self.bias_add(output, self.bias)
|
||||
return output
|
||||
|
||||
|
@ -650,3 +637,296 @@ class EmbeddingThor(Cell):
|
|||
s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
|
||||
self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
|
||||
return s
|
||||
|
||||
@constexpr
|
||||
def _make_axis_range(start, end):
|
||||
axis = tuple(range(start, end))
|
||||
return axis
|
||||
|
||||
|
||||
class EmbeddingLookupThor(Cell):
|
||||
r"""
|
||||
Returns a slice of the input tensor based on the specified indices
|
||||
and saving the information needed for THOR.
|
||||
|
||||
This module has the same function as EmbeddingLookup, but additionally saves the information A and G in the
|
||||
embeddinglookup layer needed for THOR,
|
||||
the detail can be seen in paper: https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (int): The size of the dictionary of embeddings.
|
||||
embedding_size (int): The size of each embedding vector.
|
||||
param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
|
||||
Refer to class `initializer` for the values of string when a string is specified.
|
||||
Default: 'normal'.
|
||||
target (str): Specifies the target where the op is executed. The value must in
|
||||
['DEVICE', 'CPU']. Default: 'CPU'.
|
||||
slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value must get through
|
||||
nn.EmbeddingLookup. Default: nn.EmbeddingLookup.BATCH_SLICE.
|
||||
manual_shapes (tuple): The accompaniment array in field slice mode.
|
||||
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32 or None.
|
||||
Default: None
|
||||
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true.
|
||||
Default: True.
|
||||
vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only in
|
||||
'DEVICE' target. And the moment parameter of corresponding optimizer will also be set to the cache size.
|
||||
In addition, it should be noted that it will cost the 'DEVICE' memory, so suggests setting a reasonable
|
||||
value to avoid insufficient memory.
|
||||
|
||||
Inputs:
|
||||
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `target` is neither 'CPU' nor 'DEVICE'.
|
||||
ValueError: If `slice_mode` is not one of 'batch_slice' or 'field_slice' or
|
||||
'table_row_slice' or 'table_column_slice'.
|
||||
ValueError: If `sparse` is False and `target` is 'CPU'.
|
||||
ValueError: If `slice_mode` is 'field_slice' and `manual_shapes` is None.
|
||||
TypeError: If `vocab_size` or `embedding_size` or `vocab_cache_size` is not an int.
|
||||
TypeError: If `sparse` is not a bool or `manual_shapes` is not a tuple.
|
||||
ValueError: If `vocab_size` or `embedding_size` is less than 1.
|
||||
ValueError: If `vocab_cache_size` is less than 0.
|
||||
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Examples:
|
||||
>>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
|
||||
>>> result = nn.EmbeddingLookup(4,2)(input_indices)
|
||||
>>> print(result.shape)
|
||||
(2, 2, 2)
|
||||
"""
|
||||
BATCH_SLICE = "batch_slice"
|
||||
FIELD_SLICE = "field_slice"
|
||||
TABLE_ROW_SLICE = "table_row_slice"
|
||||
TABLE_COLUMN_SLICE = "table_column_slice"
|
||||
|
||||
def __init__(self, vocab_size, embedding_size, param_init='normal',
|
||||
target='CPU', slice_mode='batch_slice', manual_shapes=None,
|
||||
max_norm=None, sparse=True, vocab_cache_size=0):
|
||||
super(EmbeddingLookupThor, self).__init__()
|
||||
Validator.check_value_type('sparse', sparse, [bool], self.cls_name)
|
||||
self.vocab_size = Validator.check_positive_int(vocab_size, 'vocab_size')
|
||||
self.vocab_cache_size = Validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
|
||||
self.target = target
|
||||
self.sparse = sparse
|
||||
self.cache_enable = self.vocab_cache_size > 0
|
||||
self.forward_unique = False
|
||||
self.dtype = mstype.float16
|
||||
if target not in ('CPU', 'DEVICE'):
|
||||
raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
|
||||
+ str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
|
||||
if not sparse and target == 'CPU':
|
||||
raise ValueError('When target is CPU, embedding_lookup must be sparse.')
|
||||
if sparse:
|
||||
self.gatherv2 = P.SparseGatherV2()
|
||||
else:
|
||||
self.gatherv2 = P.Gather()
|
||||
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
|
||||
enable_ps = _get_ps_context("enable_ps")
|
||||
if enable_ps:
|
||||
self._process_vocab_cache(slice_mode)
|
||||
self.embedding_size = Validator.check_positive_int(embedding_size, 'embedding_size')
|
||||
self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size],
|
||||
mstype.float16), name='embedding_table')
|
||||
parallel_mode = _get_parallel_mode()
|
||||
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
||||
self.gather_revert = P.Gather()
|
||||
self.reshape_first = P.Reshape()
|
||||
self.reshape = P.Reshape()
|
||||
self.unique = P.Unique()
|
||||
self.shape = P.Shape()
|
||||
if is_auto_parallel:
|
||||
self.unique = P.Unique().shard(((1,),))
|
||||
if self.cache_enable and enable_ps:
|
||||
self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
|
||||
if is_auto_parallel:
|
||||
self.unique.add_prim_attr('cache_enable', True)
|
||||
indices_shape_size = 2
|
||||
if slice_mode == "field_slice" and is_auto_parallel:
|
||||
if not manual_shapes:
|
||||
raise ValueError("in slice field mode, the manual_shapes should not be none")
|
||||
if not isinstance(manual_shapes, tuple):
|
||||
raise TypeError("manual_shapes type must be tuple(int) cannot be {}!".format(type(manual_shapes)))
|
||||
for dim in manual_shapes:
|
||||
Validator.check_positive_int(dim, 'manual shape dim', self.cls_name)
|
||||
self.gatherv2.add_prim_attr("manual_split", manual_shapes)
|
||||
self.embeddinglookup.add_prim_attr("manual_split", manual_shapes)
|
||||
self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
|
||||
self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
|
||||
elif slice_mode == "table_row_slice" and is_auto_parallel:
|
||||
full_batch = _get_full_batch()
|
||||
if (target == 'DEVICE' and not full_batch) or (self.cache_enable and enable_ps and sparse):
|
||||
indices_shape_size = 1
|
||||
self.gather_revert.shard(((1, 1), (get_group_size(),)))
|
||||
self.forward_unique = True
|
||||
indices_strategy = (1,) * indices_shape_size
|
||||
self.gatherv2.shard(((get_group_size(), 1), indices_strategy))
|
||||
self.embeddinglookup.shard(((get_group_size(), 1), indices_strategy))
|
||||
elif slice_mode == "table_column_slice" and is_auto_parallel:
|
||||
if target == 'DEVICE':
|
||||
indices_shape_size = 1
|
||||
self.gather_revert.shard(((1, get_group_size()), (1,)))
|
||||
self.forward_unique = True
|
||||
indices_strategy = (1,) * indices_shape_size
|
||||
self.gatherv2.shard(((1, get_group_size()), indices_strategy))
|
||||
self.embeddinglookup.shard(((1, get_group_size()), indices_strategy))
|
||||
elif slice_mode == "batch_slice" and is_auto_parallel:
|
||||
indices_strategy = [get_group_size()]
|
||||
indices_strategy.extend([1] * (indices_shape_size - 1))
|
||||
indices_strategy = tuple(indices_strategy)
|
||||
self.gatherv2.shard(((1, 1), indices_strategy))
|
||||
self.embeddinglookup.shard(((1, 1), indices_strategy))
|
||||
else:
|
||||
if is_auto_parallel:
|
||||
raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
|
||||
+ str(slice_mode))
|
||||
if self.cache_enable and not enable_ps:
|
||||
if parallel_mode != ParallelMode.STAND_ALONE:
|
||||
raise ValueError("parallel mode haven't supported cache enable yet.")
|
||||
self._set_cache_enable()
|
||||
self.embedding_table.unique = self.forward_unique
|
||||
self.max_norm = max_norm
|
||||
if self.max_norm is not None:
|
||||
self.max_norm = Validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
|
||||
self.max_norm = Tensor(self.max_norm, dtype=mstype.float16)
|
||||
|
||||
self.thor = True
|
||||
self.matrix_a = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)),
|
||||
name='matrix_a', requires_grad=False)
|
||||
self.matrix_g = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)),
|
||||
name="matrix_g", requires_grad=False)
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.getG = P.InsertGradientOf(self.save_gradient)
|
||||
self.cast = P.Cast()
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
self.cube_matmul = P.MatMul(transpose_a=True)
|
||||
else:
|
||||
self.cube_matmul = P.MatMul(transpose_a=True)
|
||||
self.mul = P.Mul()
|
||||
self.on_value = Tensor(1.0, self.dtype)
|
||||
self.off_value = Tensor(0.0, self.dtype)
|
||||
self.one_hot = P.OneHot()
|
||||
|
||||
|
||||
def save_gradient(self, dout):
|
||||
"""
|
||||
this function only for thor optimizer
|
||||
save_gradient
|
||||
"""
|
||||
out = dout
|
||||
shape = self.shape(dout)
|
||||
normalizer = self.cast(shape[0], mstype.float16)
|
||||
dout = self.reshape(dout, (-1, self.embedding_size))
|
||||
matrix_g = self.cube_matmul(dout, dout)
|
||||
matrix_g = self.mul(matrix_g, 1.0 / normalizer)
|
||||
matrix_g = self.cast(matrix_g, mstype.float16)
|
||||
self.matrix_g = matrix_g
|
||||
return out
|
||||
|
||||
def _set_cache_enable(self):
|
||||
"""EmbeddingLookup cache check for not ps env, which is only support 'ascend'."""
|
||||
if self.target != 'DEVICE':
|
||||
raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target.")
|
||||
if not self.sparse:
|
||||
raise ValueError("The configuration of 'vocab_cache_size' is valid only 'sparse' is true.")
|
||||
if context.get_context("device_target") != 'Ascend':
|
||||
raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'ascend'.")
|
||||
|
||||
logger.info("EmbeddingLookup cache enable takes effect.")
|
||||
self.forward_unique = True
|
||||
self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU')
|
||||
self.unique.add_prim_attr('cache_enable', True)
|
||||
self.embedding_table.cache_enable = self.cache_enable
|
||||
self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size)
|
||||
self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU')
|
||||
|
||||
def _process_vocab_cache(self, slice_mode):
|
||||
"""PS embeddingLookup cache check and process."""
|
||||
self.cache_enable = False
|
||||
if self.vocab_cache_size > 0:
|
||||
if self.target == 'CPU':
|
||||
logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
|
||||
"current target is CPU, so it will be ignored.")
|
||||
return
|
||||
enable_ps = _get_ps_context("enable_ps")
|
||||
if not enable_ps:
|
||||
logger.warning(
|
||||
"The configuration of 'vocab_cache_size' is valid only in parameter server trainning "
|
||||
"mode, current mode is not parameter server trainning mode, so it will be ignored.")
|
||||
return
|
||||
parallel_mode = _get_parallel_mode()
|
||||
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
||||
if is_auto_parallel:
|
||||
rank_size = get_group_size()
|
||||
rank_id = get_rank()
|
||||
full_batch = _get_full_batch()
|
||||
if rank_size > 1 and not (full_batch and slice_mode == "table_row_slice"):
|
||||
raise ValueError("The embeddingLookup cache of parameter server parallel only be used "
|
||||
"in 'full_batch' and 'table_row_slice' parallel strategy.")
|
||||
self.vocab_cache_size = self.vocab_cache_size * rank_size
|
||||
_set_rank_id(rank_id)
|
||||
self.cache_enable = True
|
||||
if _is_role_worker():
|
||||
self.vocab_size = self.vocab_cache_size
|
||||
if context.get_context("enable_sparse") != self.sparse:
|
||||
raise ValueError("The value of parameter 'sparse' must be same for all EmbeddingLookup "
|
||||
"kernels and equal the value of 'enable_sparse' in context setting in "
|
||||
"parameter server cache mode")
|
||||
|
||||
def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
|
||||
"""PS embeddingLookup cache enable set."""
|
||||
self.embedding_table.cache_enable = True
|
||||
self.embedding_table.is_param_ps = True
|
||||
_set_cache_enable(True)
|
||||
if self.sparse:
|
||||
self.forward_unique = True
|
||||
if _is_role_worker():
|
||||
_insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)
|
||||
|
||||
def construct(self, indices):
|
||||
if self.target == "CPU":
|
||||
out = self.embeddinglookup(self.embedding_table, indices, 0)
|
||||
else:
|
||||
if self.thor:
|
||||
if self.forward_unique:
|
||||
shp = self.shape(indices) + (self.embedding_size,)
|
||||
indices_flatten = self.reshape_first(indices, (-1,))
|
||||
unique_id, unique_idx = self.unique(indices_flatten)
|
||||
one_hot_ids = self.one_hot(indices_flatten, self.vocab_size, self.on_value, self.off_value)
|
||||
matrix_a = self.reduce_sum(one_hot_ids, 0)
|
||||
matrix_a = self.cast(matrix_a, mstype.float16)
|
||||
self.matrix_a = matrix_a
|
||||
weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
|
||||
out = self.getG(weight_unique)
|
||||
weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
|
||||
out = self.reshape(weight_flatten, shp)
|
||||
|
||||
else:
|
||||
indices_flatten = self.reshape_first(indices, (-1,))
|
||||
one_hot_ids = self.one_hot(indices_flatten, self.vocab_size, self.on_value, self.off_value)
|
||||
matrix_a = self.reduce_sum(one_hot_ids, 0)
|
||||
matrix_a = self.cast(matrix_a, mstype.float16)
|
||||
self.matrix_a = matrix_a
|
||||
out = self.gatherv2(self.embedding_table, indices, 0)
|
||||
out = self.getG(out)
|
||||
else:
|
||||
if self.forward_unique:
|
||||
shp = self.shape(indices) + (self.embedding_size,)
|
||||
indices_flatten = self.reshape_first(indices, (-1,))
|
||||
unique_id, unique_idx = self.unique(indices_flatten)
|
||||
weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
|
||||
weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
|
||||
out = self.reshape(weight_flatten, shp)
|
||||
else:
|
||||
out = self.gatherv2(self.embedding_table, indices, 0)
|
||||
if self.max_norm is not None:
|
||||
axis = _make_axis_range(F.rank(indices), F.rank(out))
|
||||
clip_by_norm = ClipByNorm(axis)
|
||||
out = clip_by_norm(out, self.max_norm)
|
||||
return out
|
||||
|
|
|
@ -20,16 +20,18 @@ from mindspore.common.parameter import Parameter, ParameterTuple
|
|||
from mindspore.common.tensor import Tensor
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.log as logger
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.nn.optim.optimizer import Optimizer
|
||||
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
|
||||
from mindspore import context
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor
|
||||
from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor, EmbeddingLookupThor
|
||||
from mindspore.nn.wrap import DistributedGradReducer
|
||||
from mindspore.train.train_thor.convert_utils import ConvertNetUtils
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
|
||||
|
||||
# Enumerates types of Layer
|
||||
Other = -1
|
||||
Conv = 1
|
||||
|
@ -120,6 +122,34 @@ def caculate_device_shape(matrix_dim, channel, is_a):
|
|||
return ll
|
||||
|
||||
|
||||
def is_conv_matmul_support_shape(matrix_a_shape, matrix_g_shape):
|
||||
"""is conv layer matmul support shape"""
|
||||
temp = (matrix_g_shape, matrix_a_shape)
|
||||
support_shape = [((4, 4, 16, 16), (49, 49, 16, 16)),
|
||||
((4, 4, 16, 16), (4, 4, 16, 16)),
|
||||
((4, 4, 16, 16), (36, 36, 16, 16)),
|
||||
((16, 16, 16, 16), (4, 4, 16, 16)),
|
||||
((4, 4, 16, 16), (16, 16, 16, 16)),
|
||||
((8, 8, 16, 16), (16, 16, 16, 16)),
|
||||
((8, 8, 16, 16), (72, 72, 16, 16)),
|
||||
((32, 32, 16, 16), (8, 8, 16, 16)),
|
||||
((32, 32, 16, 16), (16, 16, 16, 16)),
|
||||
((8, 8, 16, 16), (32, 32, 16, 16)),
|
||||
((16, 16, 16, 16), (32, 32, 16, 16)),
|
||||
((16, 16, 16, 16), (144, 144, 16, 16)),
|
||||
((64, 64, 16, 16), (16, 16, 16, 16)),
|
||||
((64, 64, 16, 16), (32, 32, 16, 16)),
|
||||
((16, 16, 16, 16), (64, 64, 16, 16)),
|
||||
((32, 32, 16, 16), (64, 64, 16, 16)),
|
||||
((32, 32, 16, 16), (288, 288, 16, 16)),
|
||||
((128, 128, 16, 16), (32, 32, 16, 16)),
|
||||
((128, 128, 16, 16), (64, 64, 16, 16)),
|
||||
((32, 32, 16, 16), (128, 128, 16, 16))]
|
||||
if temp in support_shape:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def caculate_matmul_shape(matrix_a_dim, matrix_g_dim, split_dim):
|
||||
"""get matmul shape"""
|
||||
split_dima = split_dim
|
||||
|
@ -151,20 +181,28 @@ def find_net_layertype_recur(net, layertype_map):
|
|||
cells = net.name_cells()
|
||||
for name in cells:
|
||||
subcell = cells[name]
|
||||
prefix = subcell.param_prefix
|
||||
if subcell == net:
|
||||
continue
|
||||
elif isinstance(subcell, Conv2dThor):
|
||||
layertype_map.append(Conv)
|
||||
elif isinstance(subcell, DenseThor):
|
||||
layertype_map.append(FC)
|
||||
elif isinstance(subcell, EmbeddingThor):
|
||||
elif isinstance(subcell, (EmbeddingThor, EmbeddingLookupThor)):
|
||||
layertype_map.append(Embedding)
|
||||
elif isinstance(subcell, nn.LayerNorm):
|
||||
layertype_map.append(LayerNorm)
|
||||
elif isinstance(subcell, nn.BatchNorm2d):
|
||||
if subcell.gamma.requires_grad:
|
||||
layertype_map.append(BatchNorm)
|
||||
elif isinstance(subcell, (nn.Conv2d, nn.Dense, nn.Embedding, nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose,
|
||||
nn.BatchNorm1d, nn.GroupNorm, nn.GlobalBatchNorm)):
|
||||
if isinstance(subcell, (nn.Dense, nn.Conv2d)):
|
||||
if "rpn_with_loss.rpn_convs_list." in prefix.lower():
|
||||
continue
|
||||
elif subcell.weight.requires_grad:
|
||||
layertype_map.append(Other)
|
||||
else:
|
||||
layertype_map.append(Other)
|
||||
else:
|
||||
find_net_layertype_recur(subcell, layertype_map)
|
||||
|
@ -188,6 +226,13 @@ def get_layer_counter(layer_type, layer_counter, params, idx):
|
|||
if "beta" in params[idx].name.lower():
|
||||
layer_counter = layer_counter + 1
|
||||
else:
|
||||
if "bias" in params[idx].name.lower():
|
||||
layer_counter = layer_counter + 1
|
||||
elif "weight" in params[idx].name.lower():
|
||||
if idx < len(params) - 1 and "bias" not in params[idx + 1].name.lower():
|
||||
layer_counter = layer_counter + 1
|
||||
else:
|
||||
# for example : embedding layer
|
||||
layer_counter = layer_counter + 1
|
||||
return layer_counter
|
||||
|
||||
|
@ -315,6 +360,7 @@ class ThorGpu(Optimizer):
|
|||
self.params = self.parameters
|
||||
self.use_nesterov = Validator.check_bool(use_nesterov)
|
||||
self.moments = self.params.clone(prefix="moments", init='zeros')
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
|
||||
self.net = net
|
||||
self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
|
||||
|
@ -326,7 +372,7 @@ class ThorGpu(Optimizer):
|
|||
self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
|
||||
self.damping = damping
|
||||
self._define_gpu_operator()
|
||||
print("matrix_a_cov len is {}".format(len(self.matrix_a_cov)), flush=True)
|
||||
logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
|
||||
self.thor = True
|
||||
self.matrix_a = ()
|
||||
self.matrix_g = ()
|
||||
|
@ -375,6 +421,7 @@ class ThorGpu(Optimizer):
|
|||
self.square = P.Square()
|
||||
self.expand = P.ExpandDims()
|
||||
|
||||
|
||||
def _define_gpu_reducer(self, split_indices):
|
||||
"""define gpu reducer"""
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
|
@ -602,17 +649,16 @@ class ThorAscend(Optimizer):
|
|||
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
|
||||
self.params = self.parameters
|
||||
self.moments = self.params.clone(prefix="moments", init='zeros')
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.opt = P.ApplyMomentum()
|
||||
self.net = net
|
||||
self.matrix_a_cov = ParameterTuple(filter(lambda x: 'matrix_a' in x.name, net.get_parameters()))
|
||||
self.matrix_g_cov = ParameterTuple(filter(lambda x: 'matrix_g' in x.name, net.get_parameters()))
|
||||
self.a_normalizer = ParameterTuple(filter(lambda x: 'a_normalizer' in x.name, net.get_parameters()))
|
||||
self.g_normalizer = ParameterTuple(filter(lambda x: 'g_normalizer' in x.name, net.get_parameters()))
|
||||
print("matrix_a_cov len is {}".format(len(self.matrix_a_cov)), flush=True)
|
||||
logger.info("matrix_a_cov len is {}".format(len(self.matrix_a_cov)))
|
||||
self._define_ascend_operator()
|
||||
self.C0 = 16
|
||||
self.matrix_a_dim = ()
|
||||
self.pad_a_flag = ()
|
||||
self.device_shape_pad_flag = ()
|
||||
self.diag_block_dim = 128
|
||||
self.matrix_a = ()
|
||||
|
@ -622,6 +668,11 @@ class ThorAscend(Optimizer):
|
|||
self.weight_conv_idx_map = ()
|
||||
self.weight_fim_idx_map = ()
|
||||
self.weight_layertype_idx_map = ()
|
||||
self.a_split_pad_dim_map = ()
|
||||
self.g_split_pad_dim_map = ()
|
||||
self.conv_matmul_support_map = ()
|
||||
self.batch_matmul_support_list = [1, 2, 4, 5, 6, 8, 9, 16, 18, 24, 32, 36]
|
||||
self.abs_max_support_list = [1, 2, 4, 8, 16, 5, 9, 18, 36, 32]
|
||||
self._process_matrix_init_and_weight_idx_map(self.net)
|
||||
self.matrix_a = ParameterTuple(self.matrix_a)
|
||||
self.matrix_g = ParameterTuple(self.matrix_g)
|
||||
|
@ -646,6 +697,16 @@ class ThorAscend(Optimizer):
|
|||
"""get thor frequency"""
|
||||
return self.frequency
|
||||
|
||||
def _get_pad_dim(self, matrix_dim):
|
||||
"""get diag split pad dim """
|
||||
split_pad_dim = 0
|
||||
if matrix_dim == 64:
|
||||
return split_pad_dim
|
||||
res = matrix_dim % self.diag_block_dim
|
||||
if res != 0:
|
||||
split_pad_dim = self.diag_block_dim - res
|
||||
return split_pad_dim
|
||||
|
||||
def _define_ascend_operator(self):
|
||||
"""define ascend operator"""
|
||||
self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
|
||||
|
@ -663,8 +724,10 @@ class ThorAscend(Optimizer):
|
|||
self.assign = P.Assign()
|
||||
self.cast = P.Cast()
|
||||
self.eye = P.Eye()
|
||||
self.concat = P.Concat(0)
|
||||
self.cholesky = P.CusCholeskyTrsm()
|
||||
self.vector_matmul = P.CusBatchMatMul()
|
||||
self.tbe_batch_matmul = P.BatchMatMul(transpose_a=True)
|
||||
self.fused_abs_max2 = P.CusFusedAbsMax1()
|
||||
self.matrix_combine = P.CusMatrixCombine()
|
||||
self.slice = P.Slice()
|
||||
|
@ -677,6 +740,7 @@ class ThorAscend(Optimizer):
|
|||
self.one = Tensor(1, mstype.int32)
|
||||
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
|
||||
|
||||
|
||||
def _define_ascend_reducer(self, split_indices):
|
||||
"""define ascend reducer"""
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
|
@ -691,13 +755,13 @@ class ThorAscend(Optimizer):
|
|||
if self.conv_layer_count > 0:
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum2")
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum4")
|
||||
self.grad_reducer_amax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=2)
|
||||
self.grad_reducer_gmax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=4)
|
||||
self.grad_reducer_amax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=3)
|
||||
self.grad_reducer_gmax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=5)
|
||||
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
|
||||
self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
|
||||
self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)
|
||||
self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=9)
|
||||
self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=17)
|
||||
|
||||
def _process_matrix_init_and_weight_idx_map(self, net):
|
||||
"""for Ascend, process matrix init shape, and get weight idx map"""
|
||||
|
@ -714,26 +778,31 @@ class ThorAscend(Optimizer):
|
|||
matrix_g_dim = out_channels
|
||||
matrix_a_device_shape, matrix_a_device_dim = caculate_device_shape(matrix_a_dim, in_channels, True)
|
||||
matrix_g_device_shape, matrix_g_device_dim = caculate_device_shape(matrix_g_dim, in_channels, False)
|
||||
ret = is_conv_matmul_support_shape(matrix_a_device_shape, matrix_g_device_shape)
|
||||
if ret:
|
||||
matrix_a_inv = Parameter(
|
||||
Tensor(np.reshape(np.identity(matrix_a_device_dim).astype(np.float16), matrix_a_device_shape)),
|
||||
name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
|
||||
matrix_g_inv = Parameter(
|
||||
Tensor(np.reshape(np.identity(matrix_g_device_dim).astype(np.float16), matrix_g_device_shape)),
|
||||
name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
|
||||
self.conv_matmul_support_map = self.conv_matmul_support_map + (1,)
|
||||
else:
|
||||
matrix_a_inv = Parameter(Tensor(np.eye(matrix_a_dim).astype(np.float16)),
|
||||
name='matrix_a_inv_' + str(self.thor_layer_count), requires_grad=False)
|
||||
matrix_g_inv = Parameter(Tensor(np.eye(matrix_g_dim).astype(np.float16)),
|
||||
name="matrix_g_inv_" + str(self.thor_layer_count), requires_grad=False)
|
||||
self.conv_matmul_support_map = self.conv_matmul_support_map + (0,)
|
||||
self.matrix_a = self.matrix_a + (matrix_a_inv,)
|
||||
self.matrix_g = self.matrix_g + (matrix_g_inv,)
|
||||
self.matrix_a_dim = self.matrix_a_dim + (matrix_a_dim,)
|
||||
pad_a_flag = False
|
||||
if (matrix_a_dim // self.diag_block_dim) * self.diag_block_dim != matrix_a_dim \
|
||||
and matrix_a_dim > self.diag_block_dim:
|
||||
pad_a_flag = True
|
||||
self.pad_a_flag = self.pad_a_flag + (pad_a_flag,)
|
||||
device_shape_pad_flag = False
|
||||
if matrix_a_dim != matrix_a_device_dim:
|
||||
device_shape_pad_flag = True
|
||||
self.device_shape_pad_flag = self.device_shape_pad_flag + (device_shape_pad_flag,)
|
||||
elif layer_type == FC and "bias" not in self.params[idx].name.lower():
|
||||
out_channels = weight_shape[0]
|
||||
in_channels = weight_shape[1]
|
||||
if self.conv_layer_count > 0:
|
||||
if out_channels == 1001:
|
||||
fc_matrix_a = Parameter(Tensor(np.zeros([128, 128, 16, 16]).astype(np.float16)),
|
||||
name='matrix_a_inv_' + str(self.thor_layer_count),
|
||||
|
@ -741,12 +810,34 @@ class ThorAscend(Optimizer):
|
|||
fc_matrix_g = Parameter(Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16)),
|
||||
name="matrix_g_inv_" + str(self.thor_layer_count),
|
||||
requires_grad=False)
|
||||
else:
|
||||
fc_matrix_a = Parameter(Tensor(np.eye(in_channels).astype(np.float16)),
|
||||
name='matrix_a_inv_' + str(self.thor_layer_count),
|
||||
requires_grad=False)
|
||||
fc_matrix_g = Parameter(Tensor(np.eye(out_channels).astype(np.float16)),
|
||||
name="matrix_g_inv_" + str(self.thor_layer_count),
|
||||
requires_grad=False)
|
||||
self.matrix_a = self.matrix_a + (fc_matrix_a,)
|
||||
self.matrix_g = self.matrix_g + (fc_matrix_g,)
|
||||
|
||||
if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
|
||||
self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
|
||||
self.weight_layertype_idx_map = self.weight_layertype_idx_map + (layer_type,)
|
||||
if layer_type == Embedding:
|
||||
a_pad_dim = 0
|
||||
g_pad_dim = 0
|
||||
self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)
|
||||
self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
|
||||
else:
|
||||
out_channels = weight_shape[0]
|
||||
g_pad_dim = self._get_pad_dim(out_channels)
|
||||
self.g_split_pad_dim_map = self.g_split_pad_dim_map + (g_pad_dim,)
|
||||
matrix_a_dim = weight_shape[1]
|
||||
if layer_type == Conv:
|
||||
matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3]
|
||||
a_pad_dim = self._get_pad_dim(matrix_a_dim)
|
||||
self.a_split_pad_dim_map = self.a_split_pad_dim_map + (a_pad_dim,)
|
||||
|
||||
self.thor_layer_count = self.thor_layer_count + 1
|
||||
if layer_type == Conv:
|
||||
self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
|
||||
|
@ -765,6 +856,39 @@ class ThorAscend(Optimizer):
|
|||
if "output_bias" not in self.params[idx].name.lower():
|
||||
layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)
|
||||
|
||||
def _process_batch_matmul(self, input_matrix):
|
||||
"""process batch matmul"""
|
||||
input_matrix_shape = self.shape(input_matrix)
|
||||
if input_matrix_shape[0] in self.batch_matmul_support_list:
|
||||
input_matrix = self.vector_matmul(input_matrix, input_matrix)
|
||||
else:
|
||||
input_matrix = self.tbe_batch_matmul(input_matrix, input_matrix)
|
||||
return input_matrix
|
||||
|
||||
def _process_cholesky_pad(self, pad_dim, input_matrix, matrix_shape0):
|
||||
"""process cholesky pad"""
|
||||
if pad_dim > 0:
|
||||
matrix_sup = self.eye(pad_dim, pad_dim, mstype.float32)
|
||||
matrix_sup = P.Pad(((0, 0), (matrix_shape0, 0)))(matrix_sup)
|
||||
input_matrix = P.Pad(((0, 0), (0, pad_dim)))(input_matrix)
|
||||
input_matrix = self.concat((input_matrix, matrix_sup))
|
||||
return input_matrix
|
||||
|
||||
|
||||
def _get_abs_max(self, matrix_inv, origin_dim):
|
||||
"""get matrix abs max"""
|
||||
cholesky_shape = self.shape(matrix_inv)
|
||||
if cholesky_shape[0] in self.abs_max_support_list:
|
||||
matrix_inv_max = P.CusFusedAbsMax1([origin_dim, origin_dim])(matrix_inv)
|
||||
matrix_max = self.fused_abs_max2(matrix_inv_max)
|
||||
matrix_inv = self.matrix_combine(matrix_inv)
|
||||
else:
|
||||
matrix_inv = self.matrix_combine(matrix_inv)
|
||||
matrix_abs = P.Abs()(matrix_inv)
|
||||
matrix_max = P.ReduceMax(keep_dims=False)(matrix_abs)
|
||||
return matrix_max, matrix_inv
|
||||
|
||||
|
||||
def _get_fc_ainv_ginv(self, index, damping_step, gradients, matrix_a_allreduce, matrix_g_allreduce,
|
||||
matrix_a_max_allreduce, matrix_g_max_allreduce):
|
||||
"""get fc layer ainv and ginv"""
|
||||
|
@ -780,10 +904,14 @@ class ThorAscend(Optimizer):
|
|||
g_eye = self.eye(g_shape[0], g_shape[0], mstype.float32)
|
||||
damping = self.sqrt(damping_step)
|
||||
matrix_a = matrix_a + damping * a_eye
|
||||
a_pad_dim = self.a_split_pad_dim_map[thor_layer_count]
|
||||
matrix_a = self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0])
|
||||
matrix_a_inv = self.cholesky(matrix_a)
|
||||
matrix_a_inv = self.vector_matmul(matrix_a_inv, matrix_a_inv)
|
||||
matrix_a_inv = self._process_batch_matmul(matrix_a_inv)
|
||||
|
||||
weight_shape = self.shape(self.params[index])
|
||||
out_channels = weight_shape[0]
|
||||
in_channels = weight_shape[1]
|
||||
if out_channels == 2:
|
||||
matrix_a_inv = self.matrix_combine(matrix_a_inv)
|
||||
matrix_g_inv = g_eye
|
||||
|
@ -791,27 +919,13 @@ class ThorAscend(Optimizer):
|
|||
matrix_g = self.mul(matrix_g, self.loss_scale)
|
||||
matrix_g = self.mul(matrix_g, self.batch_size_scale)
|
||||
matrix_g = matrix_g + damping * g_eye
|
||||
g_pad_dim = self.g_split_pad_dim_map[thor_layer_count]
|
||||
matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0])
|
||||
matrix_g_inv = self.cholesky(matrix_g)
|
||||
matrix_g_inv = self.vector_matmul(matrix_g_inv, matrix_g_inv)
|
||||
if out_channels == 1001:
|
||||
matrix_a_inv_max = self.fused_abs_max2(matrix_a_inv)
|
||||
a_max = self.fused_abs_max2(matrix_a_inv_max)
|
||||
matrix_a_inv = self.matrix_combine(matrix_a_inv)
|
||||
matrix_a_inv_shape = self.shape(matrix_a_inv)
|
||||
matrix_a_inv = self.reshape(matrix_a_inv,
|
||||
(matrix_a_inv_shape[0] / 16, 16,
|
||||
matrix_a_inv_shape[0] / 16, 16))
|
||||
matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
|
||||
matrix_g_inv_max = P.CusFusedAbsMax1([1001, 1001])(matrix_g_inv)
|
||||
g_max = self.fused_abs_max2(matrix_g_inv_max)
|
||||
matrix_g_inv = self.matrix_combine(matrix_g_inv)
|
||||
matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (1001, 1001))
|
||||
matrix_g_inv = P.Pad(((0, 7), (0, 7)))(matrix_g_inv)
|
||||
matrix_g_inv_shape = self.shape(matrix_g_inv)
|
||||
matrix_g_inv = self.reshape(matrix_g_inv,
|
||||
(matrix_g_inv_shape[0] / 16, 16,
|
||||
matrix_g_inv_shape[0] / 16, 16))
|
||||
matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))
|
||||
matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
|
||||
if self.conv_layer_count > 0:
|
||||
a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, in_channels)
|
||||
g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels)
|
||||
a_max = F.depend(a_max, g)
|
||||
g_max = F.depend(g_max, g)
|
||||
matrix_a_max_allreduce = matrix_a_max_allreduce + (a_max,)
|
||||
|
@ -819,6 +933,26 @@ class ThorAscend(Optimizer):
|
|||
else:
|
||||
matrix_a_inv = self.matrix_combine(matrix_a_inv)
|
||||
matrix_g_inv = self.matrix_combine(matrix_g_inv)
|
||||
|
||||
if a_pad_dim > 0:
|
||||
matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (in_channels, in_channels))
|
||||
if g_pad_dim > 0:
|
||||
matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels))
|
||||
matrix_a_inv_shape = self.shape(matrix_a_inv)
|
||||
matrix_g_combine_shape = self.shape(matrix_g_inv)
|
||||
if matrix_a_inv_shape[0] == 2048 and matrix_g_combine_shape[0] == 1001:
|
||||
matrix_a_inv = self.reshape(matrix_a_inv,
|
||||
(matrix_a_inv_shape[0] / 16, 16,
|
||||
matrix_a_inv_shape[0] / 16, 16))
|
||||
matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
|
||||
matrix_g_inv = P.Pad(((0, 7), (0, 7)))(matrix_g_inv)
|
||||
|
||||
matrix_g_inv_shape = self.shape(matrix_g_inv)
|
||||
matrix_g_inv = self.reshape(matrix_g_inv,
|
||||
(matrix_g_inv_shape[0] / 16, 16,
|
||||
matrix_g_inv_shape[0] / 16, 16))
|
||||
matrix_g_inv = self.transpose(matrix_g_inv, (2, 0, 1, 3))
|
||||
|
||||
matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
|
||||
matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
|
||||
return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce
|
||||
|
@ -830,8 +964,12 @@ class ThorAscend(Optimizer):
|
|||
thor_layer_count = self.weight_fim_idx_map[i]
|
||||
conv_layer_count = self.weight_conv_idx_map[i]
|
||||
layer_type = self.weight_layertype_idx_map[i]
|
||||
weight_shape = self.shape(self.params[i])
|
||||
out_channels = weight_shape[0]
|
||||
if layer_type == Conv:
|
||||
g = gradients[i]
|
||||
matrix_a_dim = weight_shape[1] * weight_shape[2] * weight_shape[3]
|
||||
matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
|
||||
matrix_a = self.matrix_a_cov[thor_layer_count]
|
||||
matrix_g = self.matrix_g_cov[thor_layer_count]
|
||||
matrix_a = F.depend(matrix_a, g)
|
||||
|
@ -848,18 +986,29 @@ class ThorAscend(Optimizer):
|
|||
damping_g = self.mul(damping_step, self.batch_size / g_normalizer)
|
||||
damping_a = self.sqrt(damping_a)
|
||||
matrix_a = matrix_a + damping_a * a_eye
|
||||
a_pad_dim = self.a_split_pad_dim_map[thor_layer_count]
|
||||
matrix_a = self._process_cholesky_pad(a_pad_dim, matrix_a, a_shape[0])
|
||||
matrix_a_inv = self.cholesky(matrix_a)
|
||||
matrix_a_inv = self.vector_matmul(matrix_a_inv, matrix_a_inv)
|
||||
a_max = P.CusFusedAbsMax1([self.matrix_a_dim[conv_layer_count],
|
||||
self.matrix_a_dim[conv_layer_count]])(matrix_a_inv)
|
||||
a_max = self.fused_abs_max2(a_max)
|
||||
matrix_a_inv = self.matrix_combine(matrix_a_inv)
|
||||
if self.pad_a_flag[conv_layer_count]:
|
||||
matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (self.matrix_a_dim[conv_layer_count],
|
||||
self.matrix_a_dim[conv_layer_count]))
|
||||
matrix_a_inv = self._process_batch_matmul(matrix_a_inv)
|
||||
a_max, matrix_a_inv = self._get_abs_max(matrix_a_inv, matrix_a_dim)
|
||||
|
||||
damping_g = self.sqrt(damping_g)
|
||||
matrix_g = self.mul(matrix_g, self.loss_scale)
|
||||
matrix_g = self.mul(matrix_g, self.batch_size_scale)
|
||||
matrix_g = matrix_g + damping_g * g_eye
|
||||
g_pad_dim = self.g_split_pad_dim_map[thor_layer_count]
|
||||
matrix_g = self._process_cholesky_pad(g_pad_dim, matrix_g, g_shape[0])
|
||||
matrix_g_inv = self.cholesky(matrix_g)
|
||||
matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
|
||||
g_max, matrix_g_inv = self._get_abs_max(matrix_g_inv, out_channels)
|
||||
|
||||
if a_pad_dim > 0:
|
||||
matrix_a_inv = self.slice(matrix_a_inv, (0, 0), (matrix_a_dim, matrix_a_dim))
|
||||
if g_pad_dim > 0:
|
||||
matrix_g_inv = self.slice(matrix_g_inv, (0, 0), (out_channels, out_channels))
|
||||
|
||||
if matmul_support_flag == 1:
|
||||
if self.device_shape_pad_flag[conv_layer_count]:
|
||||
weight = self.params[i]
|
||||
weight_shape = self.shape(weight)
|
||||
kernel_hw = weight_shape[2] * weight_shape[3]
|
||||
in_channels = weight_shape[1]
|
||||
matrix_a_inv = self.reshape(matrix_a_inv, (kernel_hw, in_channels, kernel_hw, in_channels))
|
||||
|
@ -870,16 +1019,6 @@ class ThorAscend(Optimizer):
|
|||
matrix_a_inv_shape[1], matrix_a_inv_shape[3])
|
||||
matrix_a_inv = self.reshape(matrix_a_inv, matrix_a_device_temp_shape)
|
||||
matrix_a_inv = self.transpose(matrix_a_inv, (2, 0, 1, 3))
|
||||
|
||||
damping_g = self.sqrt(damping_g)
|
||||
matrix_g = self.mul(matrix_g, self.loss_scale)
|
||||
matrix_g = self.mul(matrix_g, self.batch_size_scale)
|
||||
matrix_g = matrix_g + damping_g * g_eye
|
||||
matrix_g_inv = self.cholesky(matrix_g)
|
||||
matrix_g_inv = self.vector_matmul(matrix_g_inv, matrix_g_inv)
|
||||
g_max = self.fused_abs_max2(matrix_g_inv)
|
||||
g_max = self.fused_abs_max2(g_max)
|
||||
matrix_g_inv = self.matrix_combine(matrix_g_inv)
|
||||
matrix_g_inv_shape = self.shape(self.matrix_g[thor_layer_count])
|
||||
matrix_g_device_temp_shape = (matrix_g_inv_shape[0], matrix_g_inv_shape[2],
|
||||
matrix_g_inv_shape[1], matrix_g_inv_shape[3])
|
||||
|
@ -913,7 +1052,7 @@ class ThorAscend(Optimizer):
|
|||
matrix_g = self.mul(matrix_g, self.batch_size_scale)
|
||||
matrix_g = matrix_g + damping * g_eye
|
||||
matrix_g_inv = self.cholesky(matrix_g)
|
||||
matrix_g_inv = self.vector_matmul(matrix_g_inv, matrix_g_inv)
|
||||
matrix_g_inv = self._process_batch_matmul(matrix_g_inv)
|
||||
matrix_g_inv = self.matrix_combine(matrix_g_inv)
|
||||
matrix_a_allreduce = matrix_a_allreduce + (matrix_a_inv,)
|
||||
matrix_g_allreduce = matrix_g_allreduce + (matrix_g_inv,)
|
||||
|
@ -951,16 +1090,39 @@ class ThorAscend(Optimizer):
|
|||
for i in range(params_len):
|
||||
g = gradients[i]
|
||||
thor_layer_count = self.weight_fim_idx_map[i]
|
||||
conv_layer_count = self.weight_conv_idx_map[i]
|
||||
layer_type = self.weight_layertype_idx_map[i]
|
||||
matrix_a = self.matrix_a[thor_layer_count]
|
||||
matrix_g = self.matrix_g[thor_layer_count]
|
||||
matrix_max = self.matrix_max_inv[thor_layer_count]
|
||||
grad_shape = self.shape(g)
|
||||
if layer_type == FC:
|
||||
if grad_shape[0] == 1001:
|
||||
g = self.cube_matmul_left_fc(matrix_g, g)
|
||||
g = self.cube_matmul_right_fc(g, matrix_a, matrix_max)
|
||||
else:
|
||||
temp_a = self.cast(matrix_a, mstype.float16)
|
||||
temp_g = self.cast(matrix_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, matrix_max)
|
||||
elif layer_type == Conv:
|
||||
matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
|
||||
if matmul_support_flag == 1:
|
||||
g = self.cube_matmul_left(matrix_g, g)
|
||||
g = self.cube_matmul_right_mul(g, matrix_a, matrix_max)
|
||||
else:
|
||||
g = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
|
||||
temp_a = self.cast(matrix_a, mstype.float16)
|
||||
temp_g = self.cast(matrix_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, matrix_max)
|
||||
g = self.reshape(g, grad_shape)
|
||||
new_grads = new_grads + (g,)
|
||||
else:
|
||||
for i in range(params_len):
|
||||
|
@ -994,15 +1156,37 @@ class ThorAscend(Optimizer):
|
|||
"""get second gradient by matmul"""
|
||||
conv_layer_count = self.weight_conv_idx_map[index]
|
||||
layer_type = self.weight_layertype_idx_map[index]
|
||||
grad_shape = self.shape(g)
|
||||
if layer_type == FC:
|
||||
if grad_shape[0] == 1001:
|
||||
g = self.cube_matmul_left_fc(temp_g, g)
|
||||
g = self.cube_matmul_right_fc(g, temp_a, temp_max)
|
||||
else:
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, temp_max)
|
||||
elif layer_type == Conv:
|
||||
a_normalizer = self.a_normalizer[conv_layer_count]
|
||||
a_normalizer = F.depend(a_normalizer, g)
|
||||
temp_max = self.mul(temp_max, self.batch_size / a_normalizer)
|
||||
matmul_support_flag = self.conv_matmul_support_map[conv_layer_count]
|
||||
if matmul_support_flag == 1:
|
||||
g = self.cube_matmul_left(temp_g, g)
|
||||
g = self.cube_matmul_right_mul(g, temp_a, temp_max)
|
||||
else:
|
||||
g = self.reshape(g, (grad_shape[0], grad_shape[1] * grad_shape[2] * grad_shape[3]))
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
g = self.matmul(temp_g, g)
|
||||
g = self.matmul(g, temp_a)
|
||||
g = self.cast(g, mstype.float32)
|
||||
g = self.mul(g, temp_max)
|
||||
g = self.reshape(g, grad_shape)
|
||||
return g, temp_max
|
||||
|
||||
def _get_second_grad_by_layertype(self, index, matrix_a_allreduce, matrix_g_allreduce, g, damping_step):
|
||||
|
|
|
@ -63,9 +63,19 @@ def cus_matrix_combine(input_x, output, kernel_name="cus_matrix_combine"):
|
|||
repeat_real = tiling_dim * matrix_dim // 64
|
||||
if repeat_real <= 255:
|
||||
tik_instance.vector_dup(64, input_x_ub, zero, repeat_real, 1, 8)
|
||||
else:
|
||||
elif repeat_real <= 510:
|
||||
tik_instance.vector_dup(64, input_x_ub, zero, 255, 1, 8)
|
||||
tik_instance.vector_dup(64, input_x_ub[255 * 64], zero, repeat_real - 255, 1, 8)
|
||||
elif repeat_real <= 765:
|
||||
tik_instance.vector_dup(64, input_x_ub, zero, 255, 1, 8)
|
||||
tik_instance.vector_dup(64, input_x_ub[255 * 64], zero, 255, 1, 8)
|
||||
tik_instance.vector_dup(64, input_x_ub[510 * 64], zero, repeat_real - 510, 1, 8)
|
||||
else:
|
||||
tik_instance.vector_dup(64, input_x_ub, zero, 255, 1, 8)
|
||||
tik_instance.vector_dup(64, input_x_ub[255 * 64], zero, 255, 1, 8)
|
||||
tik_instance.vector_dup(64, input_x_ub[510 * 64], zero, 255, 1, 8)
|
||||
tik_instance.vector_dup(64, input_x_ub[765 * 64], zero, repeat_real - 765, 1, 8)
|
||||
|
||||
with tik_instance.for_range(0, tiling_dim) as j:
|
||||
tik_instance.data_move(input_x_ub[j, 128 * i], input_x[i, block_index * tiling_dim + j, 0],
|
||||
0, 1, 16, 0, 0)
|
||||
|
|
|
@ -25,10 +25,10 @@ new_im2col_op_info = TBERegOp("NewIm2Col") \
|
|||
.attr("ksizes", "required", "listInt", "all") \
|
||||
.attr("strides", "optional", "listInt", "all", "1") \
|
||||
.attr("dilations", "optional", "listInt", "all", "1") \
|
||||
.attr("padding_mode", "optional", "listInt", "all", "SAME") \
|
||||
.attr("padding_mode", "optional", "str", "all", "SAME") \
|
||||
.attr("pad_list", "optional", "listInt", "all", "0") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.output(0, "output", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I8_5HD, DataType.I8_Default) \
|
||||
.get_op_info()
|
||||
|
|
|
@ -577,11 +577,11 @@ class NewIm2Col(PrimitiveWithInfer):
|
|||
stride_w = self.strides
|
||||
dilation_h = self.dilations
|
||||
dilation_w = self.dilations
|
||||
if self.pad_mode == "VALID":
|
||||
if self.padding_mode == "VALID":
|
||||
h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
|
||||
w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
|
||||
pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
|
||||
elif self.pad_mode == "SAME":
|
||||
elif self.padding_mode == "SAME":
|
||||
h_out = math.ceil(x_shape[2] / stride_h)
|
||||
w_out = math.ceil(x_shape[3] / stride_w)
|
||||
pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2])
|
||||
|
@ -601,7 +601,7 @@ class NewIm2Col(PrimitiveWithInfer):
|
|||
|
||||
def infer_dtype(self, x_dtype):
|
||||
"infer dtype"
|
||||
valid_dtypes = [mstype.float16, mstype.float32]
|
||||
valid_dtypes = [mstype.float16, mstype.int8]
|
||||
validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
|
||||
return x_dtype
|
||||
|
||||
|
|
|
@ -27,7 +27,9 @@ class ConvertNetUtils:
|
|||
def __init__(self):
|
||||
self._convert_method_map = {nn.Dense: ConvertNetUtils._convert_dense,
|
||||
nn.Embedding: ConvertNetUtils._convert_embedding,
|
||||
nn.Conv2d: ConvertNetUtils._convert_conv2d}
|
||||
nn.Conv2d: ConvertNetUtils._convert_conv2d,
|
||||
nn.EmbeddingLookup: ConvertNetUtils._convert_embeddinglookup}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _convert_dense(subcell):
|
||||
|
@ -63,6 +65,7 @@ class ConvertNetUtils:
|
|||
new_subcell.bias = subcell.bias
|
||||
return new_subcell
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _convert_embedding(subcell):
|
||||
"""
|
||||
|
@ -74,6 +77,20 @@ class ConvertNetUtils:
|
|||
new_subcell.embedding_table = subcell.embedding_table
|
||||
return new_subcell
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _convert_embeddinglookup(subcell):
|
||||
"""
|
||||
convert embedding cell to second_order cell
|
||||
"""
|
||||
new_subcell = nn.EmbeddingLookupThor(vocab_size=subcell.vocab_size,
|
||||
embedding_size=subcell.embedding_size,
|
||||
target=subcell.target, sparse=subcell.sparse,
|
||||
vocab_cache_size=subcell.vocab_cache_size)
|
||||
new_subcell.embedding_table = subcell.embedding_table
|
||||
return new_subcell
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _convert_conv2d(subcell):
|
||||
"""
|
||||
|
@ -87,11 +104,22 @@ class ConvertNetUtils:
|
|||
pad_mode = subcell.pad_mode
|
||||
has_bias = subcell.has_bias
|
||||
weight = subcell.weight
|
||||
|
||||
new_subcell = nn.Conv2dThor(in_channel, out_channel,
|
||||
kernel_size=kernel_size, stride=stride, padding=padding, pad_mode=pad_mode,
|
||||
has_bias=has_bias, weight_init=weight)
|
||||
return new_subcell
|
||||
|
||||
def _need_change(self, subcell, prefix):
|
||||
"""for thor layers, need to change"""
|
||||
if isinstance(subcell, (nn.Dense, nn.Conv2d)) and subcell.weight.requires_grad:
|
||||
if "rpn_with_loss.rpn_convs_list." in prefix.lower() or "wide" in prefix.lower():
|
||||
return False
|
||||
return True
|
||||
if isinstance(subcell, (nn.Embedding, nn.EmbeddingLookup)) and subcell.embedding_table.requires_grad:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _convert_to_thor_net(self, net):
|
||||
"""
|
||||
Convert net to thor net
|
||||
|
@ -102,13 +130,14 @@ class ConvertNetUtils:
|
|||
subcell = cells[name]
|
||||
if subcell == net:
|
||||
continue
|
||||
elif isinstance(subcell, (nn.DenseThor, nn.Conv2dThor, nn.EmbeddingThor)):
|
||||
elif isinstance(subcell, (nn.DenseThor, nn.Conv2dThor, nn.EmbeddingThor, nn.EmbeddingLookupThor)):
|
||||
continue
|
||||
elif isinstance(subcell, (nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose, nn.BatchNorm1d, nn.GroupNorm,
|
||||
nn.GlobalBatchNorm, nn.LayerNorm, nn.BatchNorm2d, nn.MaxPool2d)):
|
||||
continue
|
||||
elif isinstance(subcell, (nn.Embedding, nn.Dense, nn.Conv2d)):
|
||||
elif isinstance(subcell, (nn.Embedding, nn.Dense, nn.Conv2d, nn.EmbeddingLookup)):
|
||||
prefix = subcell.param_prefix
|
||||
if self._need_change(subcell, prefix):
|
||||
new_subcell = self._convert_method_map[type(subcell)](subcell)
|
||||
new_subcell.update_parameters_name(prefix + '.')
|
||||
net.insert_child_to_cell(name, new_subcell)
|
||||
|
@ -119,6 +148,7 @@ class ConvertNetUtils:
|
|||
if isinstance(net, nn.SequentialCell) and change:
|
||||
net.cell_list = list(net.cells())
|
||||
|
||||
|
||||
def convert_to_thor_net(self, net):
|
||||
"""
|
||||
This interface is used to convert a network to thor layer network, in order to calculate and store the
|
||||
|
@ -174,7 +204,7 @@ class ConvertModelUtils:
|
|||
Otherwise, scale the loss by LossScaleManager and optimizer can not be None. It is a key argument.
|
||||
e.g. Use `loss_scale_manager=None` to set the value.
|
||||
keep_batchnorm_fp32 (bool): Keep Batchnorm running in `float32`. If True, the level setting before
|
||||
will be overwritten. Default: True.
|
||||
will be overwritten. Default: False.
|
||||
|
||||
Returns:
|
||||
model (Object): High-Level API for Training.
|
||||
|
|
|
@ -12,8 +12,11 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Model."""
|
||||
|
||||
"""High-Level API for Second Order Training or Testing.
|
||||
Second-order optimizer THOR reduces the computation workload and improves the computation speed by reducing the
|
||||
frequency of updating the second-order matrix. In order to optimize the overall performance, the ModelThor class
|
||||
is redefined to inherit the Model class provided by MindSpore. The parameter of THOR for controlling the frequency
|
||||
of updating the second-order matrix can be obtained by the ModelThor class. """
|
||||
import math
|
||||
from mindspore.train.callback import RunContext
|
||||
from mindspore import context
|
||||
|
|
|
@ -33,8 +33,6 @@ label_smooth_factor: 0.1
|
|||
lr_init: 0.05803
|
||||
lr_decay: 4.04839
|
||||
lr_end_epoch: 53
|
||||
lars_epsilon: 0.0
|
||||
lars_coefficient: 0.001
|
||||
damping_init: 0.02714
|
||||
damping_decay: 0.50036
|
||||
frequency: 834
|
||||
|
@ -54,8 +52,8 @@ enable_cache: False
|
|||
cache_session_id: ""
|
||||
mode_name: "GRAPH"
|
||||
acc_mode: "O0"
|
||||
conv_init: "XavierUniform"
|
||||
dense_init: "TruncatedNormal"
|
||||
conv_init: "HeUniform"
|
||||
dense_init: "HeUniform"
|
||||
all_reduce_fusion_config:
|
||||
- 85
|
||||
- 160
|
||||
|
|
|
@ -33,8 +33,6 @@ label_smooth_factor: 0.1
|
|||
lr_init: 0.05672
|
||||
lr_decay: 4.9687
|
||||
lr_end_epoch: 50
|
||||
lars_epsilon: 0.0
|
||||
lars_coefficient: 0.001
|
||||
damping_init: 0.02345
|
||||
damping_decay: 0.5467
|
||||
frequency: 834
|
||||
|
@ -54,8 +52,8 @@ enable_cache: False
|
|||
cache_session_id: ""
|
||||
mode_name: "GRAPH"
|
||||
acc_mode: "O0"
|
||||
conv_init: "XavierUniform"
|
||||
dense_init: "TruncatedNormal"
|
||||
conv_init: "HeUniform"
|
||||
dense_init: "HeUniform"
|
||||
all_reduce_fusion_config:
|
||||
- 85
|
||||
- 160
|
||||
|
|
|
@ -21,6 +21,7 @@ import mindspore.common.dtype as mstype
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.tensor import Tensor
|
||||
from src.model_utils.config import config
|
||||
|
||||
|
||||
def conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
|
||||
|
@ -208,6 +209,8 @@ class ResidualBlock(nn.Cell):
|
|||
|
||||
self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
|
||||
self.bn3 = _bn(out_channel)
|
||||
if config.optimizer == "Thor":
|
||||
self.bn3 = _bn_last(out_channel)
|
||||
if self.se_block:
|
||||
self.se_global_pool = P.ReduceMean(keep_dims=False)
|
||||
self.se_dense_0 = _fc(out_channel, int(out_channel / 4), use_se=self.use_se)
|
||||
|
|
|
@ -28,6 +28,7 @@ from mindspore.train.model import Model
|
|||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.nn.optim import thor
|
||||
import mindspore.nn as nn
|
||||
import mindspore.dataset as ds
|
||||
|
||||
|
@ -41,7 +42,7 @@ from tests.st.networks.models.resnet50.src_thor.config import config as thor_con
|
|||
from tests.st.networks.models.resnet50.src_thor.dataset import create_dataset as create_dataset_thor
|
||||
from tests.st.networks.models.resnet50.src_thor.model_thor import Model as THOR_Model
|
||||
from tests.st.networks.models.resnet50.src_thor.resnet import resnet50 as resnet50_thor
|
||||
from tests.st.networks.models.resnet50.src_thor.thor import THOR
|
||||
|
||||
|
||||
MINDSPORE_HCCL_CONFIG_PATH = "/home/workspace/mindspore_config/hccl/rank_tabel_4p/rank_table_4p_1.json"
|
||||
MINDSPORE_HCCL_CONFIG_PATH_2 = "/home/workspace/mindspore_config/hccl/rank_tabel_4p/rank_table_4p_2.json"
|
||||
|
@ -268,8 +269,8 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
|
|||
damping = get_thor_damping(0, 0.02714, 0.50036, 70, 5004)
|
||||
# optimizer
|
||||
split_indices = [26, 53]
|
||||
opt = THOR(net, Tensor(lr), Tensor(damping), thor_config.momentum, thor_config.weight_decay, thor_config.loss_scale,
|
||||
thor_config.batch_size, split_indices=split_indices)
|
||||
opt = thor(net, Tensor(lr), Tensor(damping), thor_config.momentum, thor_config.weight_decay, thor_config.loss_scale,
|
||||
thor_config.batch_size, split_indices=split_indices, frequency=thor_config.frequency)
|
||||
|
||||
# evaluation network
|
||||
dist_eval_network = ClassifyCorrectCell(net)
|
||||
|
|
Loading…
Reference in New Issue