add comments for thor api

This commit is contained in:
sl_wang 2021-06-10 11:59:36 +08:00
parent 9852cced86
commit 54890b88fc
7 changed files with 193 additions and 58 deletions

View File

@ -552,7 +552,7 @@ class EmbeddingThor(Cell):
Tensor of output shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`. Tensor of output shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`.
Examples: Examples:
>>> net = nn.Embedding(20000, 768, True) >>> net = nn.EmbeddingThor(20000, 768, True)
>>> input_data = Tensor(np.ones([8, 128]), mindspore.int32) >>> input_data = Tensor(np.ones([8, 128]), mindspore.int32)
>>> >>>
>>> # Maps the input word IDs to word embedding. >>> # Maps the input word IDs to word embedding.

View File

@ -27,7 +27,7 @@ from mindspore import context
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor
from mindspore.nn.wrap import DistributedGradReducer from mindspore.nn.wrap import DistributedGradReducer
from mindspore.train.train_thor.convert_utils import ConvertNetUntils from mindspore.train.train_thor.convert_utils import ConvertNetUtils
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
# Enumerates types of Layer # Enumerates types of Layer
@ -101,6 +101,17 @@ def clip_gradient(enable_clip_grad, gradients):
C0 = 16 C0 = 16
def _check_param(momentum, frequency, lr, cls_name):
"""Check param."""
Validator.check_value_type("momentum", momentum, [float], cls_name)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
Validator.check_value_type("frequency", frequency, [int], cls_name)
if isinstance(frequency, int) and frequency < 2:
raise ValueError("frequency should be at least 2, but got frequency {}".format(frequency))
Validator.check_value_type("learning rate", lr, [Tensor], cls_name)
def caculate_device_shape(matrix_dim, channel, is_a): def caculate_device_shape(matrix_dim, channel, is_a):
ll = (0) ll = (0)
if is_a: if is_a:
@ -188,13 +199,103 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False, use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False,
frequency=100): frequency=100):
r""" r"""
Updates gradients by the THOR algorithm. Updates gradients by second-order algorithm--THOR.
Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation (THOR) algorithm is proposed in: Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation (THOR) algorithm is proposed in:
`THOR: Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation `THOR: Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation
<https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf>`_ <https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf>`_
The updating formulas are as follows,
.. math::
\begin{array}{ll} \\
A_i = a_i{a_i}^T \\
G_i = D_{s_i}{ D_{s_i}}^T \\
m_i = \beta * m_i + ({G_i^{(k)}}+\lambda I)^{-1}) g_i ({\overline A_{i-1}^{(k)}}+\lambda I)^{-1} \\
w_i = w_i - \alpha * m_i \\
\end{array}
:math:`D_{s_i}` represents the derivative of the loss function of the output of the i-th layer,
:math:`a_{i-1}` represents the input of i-th layer,and which is the activations of previous layer,
:math:`\beta` represents momentum, :math:`I` represents the identity matrix,
:math:`\overline A` represents the transpose of matrix A,
:math:`\lambda` represents 'damping', :math:`g_i` represents gradients of the i-th layer,
:math:`\otimes` represents Kronecker product, :math:`\alpha` represents 'learning rate'
Note:
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
When separating parameter groups, if you want to centralize the gradient, set grad_centralization to True,
but the gradient centralization can only be applied to the parameters of the convolution layer.
If the parameters of the non convolution layer are set to True, an error will be reported.
To improve parameter groups performance, the customized order of parameters can be supported.
Args:
net (Cell): The training network.
learning_rate (Tensor): A value for the learning rate.
damping (Tensor): A value for the damping.
momentum (float): Hyper-parameter of type float, means momentum for the moving average. It must be at least 0.0.
weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0. Default: 0.0.
loss_scale (float): A value for the loss scale. It must be greater than 0.0. In general, use the
default value. Default: 1.0.
batch_size (int): The size of a batch. Default: 32
use_nesterov (bool): Enable Nesterov momentum. Default: False.
decay_filter (function): A function to determine which layers the weight decay applied to. And it
only works when the weight_decay > 0. Default: lambda x: x.name not in []
split_indices (list): Set allreduce fusion strategy by A/G layer indices . Only works when distributed
computing. ResNet50 as an example, there are 54 layers of A/G respectively, when split_indices is set
to [26, 53], it means A/G is divided into two groups to allreduce, one is 0~26 layer, and the other
is 27~53. Default: None
enable_clip_grad (bool): Whether to clip the gradients. Default: False
frequency(int): The update interval of A/G and $A^{-1}/G^{-1}$. When frequency equals N (N is greater than 1),
A/G and $A^{-1}/G^{-1}$ will be updated every N steps, and other steps will use the stale A/G and
$A^{-1}/G^{-1}$ to update weights. Default: 100.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
tuple[bool], all elements are True.
Raises:
TypeError: If `learning_rate` is not Tensor.
TypeError: If `loss_scale`,`momentum` or `frequency` is not a float.
TypeError: If `weight_decay` is neither float nor int.
TypeError: If `use_nesterov` is not a bool.
ValueError: If `loss_scale` is less than or equal to 0.
ValueError: If `weight_decay` or `momentum` is less than 0.
ValueError: If `frequency` is not int.
ValueError: If `frequency` is less than 2.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> net = Net()
>>> optim = thor(net, lr=Tensor(1e-3), damping=Tensor(1e-3), momentum=0.9)
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = ConvertModelUtils().convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,
... loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O2", keep_batchnorm_fp32=False)
>>> model.train(config.epoch_size, dataset, callbacks=cb, sink_size=100, dataset_sink_mode=True)
""" """
context.set_context(max_call_depth=10000) context.set_context(max_call_depth=10000)
ConvertNetUntils().convert_to_thor_net(net) ConvertNetUtils().convert_to_thor_net(net)
if context.get_context("device_target") == "Ascend": if context.get_context("device_target") == "Ascend":
return ThorAscend(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, decay_filter, return ThorAscend(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, decay_filter,
split_indices=split_indices, enable_clip_grad=enable_clip_grad, frequency=frequency) split_indices=split_indices, enable_clip_grad=enable_clip_grad, frequency=frequency)
@ -212,9 +313,7 @@ class ThorGpu(Optimizer):
enable_clip_grad=False, frequency=100): enable_clip_grad=False, frequency=100):
params = filter(lambda x: x.requires_grad, net.get_parameters()) params = filter(lambda x: x.requires_grad, net.get_parameters())
super(ThorGpu, self).__init__(learning_rate, params, weight_decay, loss_scale) super(ThorGpu, self).__init__(learning_rate, params, weight_decay, loss_scale)
Validator.check_value_type("momentum", momentum, [float], self.cls_name) _check_param(momentum, frequency, learning_rate, self.__class__.__name__)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters self.params = self.parameters
self.use_nesterov = Validator.check_bool(use_nesterov) self.use_nesterov = Validator.check_bool(use_nesterov)
@ -448,18 +547,14 @@ class ThorGpu(Optimizer):
matrix_a = matrix_a_allreduce[thor_layer_count] matrix_a = matrix_a_allreduce[thor_layer_count]
matrix_g = matrix_g_allreduce[thor_layer_count] matrix_g = matrix_g_allreduce[thor_layer_count]
g = self.update_gradient(matrix_g, g, matrix_a) g = self.update_gradient(matrix_g, g, matrix_a)
fake_a = self.assign(self.matrix_a[thor_layer_count], matrix_a) self.assign(self.matrix_a[thor_layer_count], matrix_a)
fake_g = self.assign(self.matrix_g[thor_layer_count], matrix_g) self.assign(self.matrix_g[thor_layer_count], matrix_g)
g = F.depend(g, fake_a)
g = F.depend(g, fake_g)
g = self._reshape_gradient(conv_layer_count, g, g_shape) g = self._reshape_gradient(conv_layer_count, g, g_shape)
elif layer_type == Embedding: elif layer_type == Embedding:
matrix_a = matrix_a_allreduce[thor_layer_count] matrix_a = matrix_a_allreduce[thor_layer_count]
matrix_g = matrix_g_allreduce[thor_layer_count] matrix_g = matrix_g_allreduce[thor_layer_count]
fake_a = self.assign(self.matrix_a[thor_layer_count], matrix_a) self.assign(self.matrix_a[thor_layer_count], matrix_a)
fake_g = self.assign(self.matrix_g[thor_layer_count], matrix_g) self.assign(self.matrix_g[thor_layer_count], matrix_g)
g = F.depend(g, fake_a)
g = F.depend(g, fake_g)
temp_a = self.expand(matrix_a, 1) temp_a = self.expand(matrix_a, 1)
g = self.mul(temp_a, g) g = self.mul(temp_a, g)
g = self.matmul(g, matrix_g) g = self.matmul(g, matrix_g)
@ -507,8 +602,7 @@ class ThorAscend(Optimizer):
decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False, frequency=100): decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False, frequency=100):
params = filter(lambda x: x.requires_grad, net.get_parameters()) params = filter(lambda x: x.requires_grad, net.get_parameters())
super(ThorAscend, self).__init__(learning_rate, params, weight_decay, loss_scale) super(ThorAscend, self).__init__(learning_rate, params, weight_decay, loss_scale)
if isinstance(momentum, float) and momentum < 0.0: _check_param(momentum, frequency, learning_rate, self.__class__.__name__)
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters self.params = self.parameters
self.moments = self.params.clone(prefix="moments", init='zeros') self.moments = self.params.clone(prefix="moments", init='zeros')
@ -845,10 +939,8 @@ class ThorAscend(Optimizer):
"""process thor graph fc layer""" """process thor graph fc layer"""
temp_a = matrix_a_allreduce[thor_layer_count] temp_a = matrix_a_allreduce[thor_layer_count]
temp_g = matrix_g_allreduce[thor_layer_count] temp_g = matrix_g_allreduce[thor_layer_count]
fake_a = self.assign(self.matrix_a_cov[thor_layer_count], temp_a) self.assign(self.matrix_a_cov[thor_layer_count], temp_a)
fake_g = self.assign(self.matrix_g_cov[thor_layer_count], temp_g) self.assign(self.matrix_g_cov[thor_layer_count], temp_g)
g = F.depend(g, fake_a)
g = F.depend(g, fake_g)
temp_a = self.cast(temp_a, mstype.float16) temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16) temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16) g = self.cast(g, mstype.float16)
@ -925,10 +1017,8 @@ class ThorAscend(Optimizer):
if layer_type == Embedding: if layer_type == Embedding:
temp_a_ori = matrix_a_allreduce[thor_layer_count] temp_a_ori = matrix_a_allreduce[thor_layer_count]
temp_g = matrix_g_allreduce[thor_layer_count] temp_g = matrix_g_allreduce[thor_layer_count]
fake_a = self.assign(self.matrix_a_cov[thor_layer_count], temp_a_ori) self.assign(self.matrix_a_cov[thor_layer_count], temp_a_ori)
fake_g = self.assign(self.matrix_g_cov[thor_layer_count], temp_g) self.assign(self.matrix_g_cov[thor_layer_count], temp_g)
g = F.depend(g, fake_a)
g = F.depend(g, fake_g)
temp_a = self.expand(temp_a_ori, 1) temp_a = self.expand(temp_a_ori, 1)
g = self.mul(temp_a, g) g = self.mul(temp_a, g)
temp_g = self.cast(temp_g, mstype.float16) temp_g = self.cast(temp_g, mstype.float16)
@ -982,12 +1072,9 @@ class ThorAscend(Optimizer):
temp_a = self.cast(temp_a, mstype.float16) temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16) temp_g = self.cast(temp_g, mstype.float16)
g, temp_max = self._get_second_grad_by_matmul(i, temp_a, temp_g, g, temp_max) g, temp_max = self._get_second_grad_by_matmul(i, temp_a, temp_g, g, temp_max)
fake_a = self.assign(self.matrix_a[thor_layer_count], temp_a) self.assign(self.matrix_a[thor_layer_count], temp_a)
fake_g = self.assign(self.matrix_g[thor_layer_count], temp_g) self.assign(self.matrix_g[thor_layer_count], temp_g)
fake_max = self.assign(self.matrix_max_inv[thor_layer_count], temp_max) self.assign(self.matrix_max_inv[thor_layer_count], temp_max)
g = F.depend(g, fake_a)
g = F.depend(g, fake_g)
g = F.depend(g, fake_max)
new_grads = new_grads + (g,) new_grads = new_grads + (g,)
gradients = new_grads gradients = new_grads
else: else:

View File

@ -85,7 +85,7 @@ class CusCholeskyTrsm(PrimitiveWithInfer):
Examples: Examples:
>>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float32) >>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float32)
>>> cus_choleskytrsm = ops.CusCholeskyTrsm() >>> cus_choleskytrsm = ops.CusCholeskyTrsm()
>>> output = matmul(input_x) >>> output = cus_choleskytrsm(input_x)
""" """
@prim_attr_register @prim_attr_register

View File

@ -14,6 +14,6 @@
# ============================================================================ # ============================================================================
"""convert to second order related classes and functions.""" """convert to second order related classes and functions."""
from .convert_utils import ConvertNetUntils, ConvertModelUtils from .convert_utils import ConvertNetUtils, ConvertModelUtils
__all__ = ["ConvertNetUntils", "ConvertModelUtils"] __all__ = ["ConvertNetUtils", "ConvertModelUtils"]

View File

@ -13,29 +13,28 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" """
convert utils for second order optimizer: thor Conversion interface for second-order optimizer thor
""" """
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore import context from mindspore import context
class ConvertNetUntils(): class ConvertNetUtils():
""" """
Convert net to thor layer net Convert net to thor layer net
""" """
def __init__(self): def __init__(self):
self._convert_method_map = {nn.Dense: ConvertNetUntils._convert_dense, self._convert_method_map = {nn.Dense: ConvertNetUtils._convert_dense,
nn.Embedding: ConvertNetUntils._convert_embedding, nn.Embedding: ConvertNetUtils._convert_embedding,
nn.Conv2d: ConvertNetUntils._convert_conv2d} nn.Conv2d: ConvertNetUtils._convert_conv2d}
@staticmethod @staticmethod
def _convert_dense(subcell): def _convert_dense(subcell):
""" """
convert dense cell to second_order cell Convert dense cell to second-order cell
""" """
weight = subcell.weight weight = subcell.weight
act_name = None act_name = None
if subcell.activation_flag: if subcell.activation_flag:
@ -69,7 +68,7 @@ class ConvertNetUntils():
@staticmethod @staticmethod
def _convert_embedding(subcell): def _convert_embedding(subcell):
""" """
convert embedding cell to second_order cell Convert embedding cell to second-order cell
""" """
new_subcell = nn.EmbeddingThor(vocab_size=subcell.vocab_size, new_subcell = nn.EmbeddingThor(vocab_size=subcell.vocab_size,
embedding_size=subcell.embedding_size, embedding_size=subcell.embedding_size,
@ -81,7 +80,7 @@ class ConvertNetUntils():
@staticmethod @staticmethod
def _convert_conv2d(subcell): def _convert_conv2d(subcell):
""" """
convert conv2d cell to second_order cell Convert conv2d cell to second-order cell
""" """
out_channel = subcell.out_channels out_channel = subcell.out_channels
in_channel = subcell.in_channels in_channel = subcell.in_channels
@ -99,7 +98,7 @@ class ConvertNetUntils():
def _convert_to_thor_net(self, net): def _convert_to_thor_net(self, net):
""" """
convert net to thor net Convert net to thor net
""" """
cells = net.name_cells() cells = net.name_cells()
change = False change = False
@ -131,25 +130,76 @@ class ConvertNetUntils():
def convert_to_thor_net(self, net): def convert_to_thor_net(self, net):
""" """
api for convert net to thor net This interface is used to convert a network to thor layer network, in order to calculate and store the
second-order information matrix.
Notes:
This interface is automatically called by the second-order optimizer thor.
Args:
net (Cell): network to be trained by the second-order optimizer thor.
Examples:
>>> ConvertNetUtils().convert_to_thor_net(net)
""" """
net.update_cell_prefix() net.update_cell_prefix()
self._convert_to_thor_net(net) self._convert_to_thor_net(net)
net.update_cell_type("second_order") net.update_cell_type("second-order")
class ConvertModelUtils(): class ConvertModelUtils():
""" """
convert model to thor model utils Convert model to thor model.
""" """
@staticmethod @staticmethod
def convert_to_thor_model(model, network, loss_fn=None, optimizer=None, metrics=None, amp_level="O0", def convert_to_thor_model(model, network, loss_fn=None, optimizer=None, metrics=None, amp_level="O0",
loss_scale_manager=None, keep_batchnorm_fp32=False): loss_scale_manager=None, keep_batchnorm_fp32=False):
"""
This interface is used to convert model to thor model.
Args:
model (Object): High-Level API for Training.
`Model` groups layers into an object with training features.
network (Cell): A training network.
loss_fn (Cell): Objective function. Default: None.
optimizer (Cell): Optimizer used to updating the weights. Default: None.
metrics (Union[dict, set]): A Dictionary or a set of metrics to be evaluated by the model during
training. eg: {'accuracy', 'recall'}. Default: None.
amp_level (str): Level for mixed precision training. Supports ["O0", "O2", "O3", "auto"]. Default: "O0".
- O0: Do not change.
- O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
- auto: Set level to recommended level in different devices. O2 is recommended on GPU, O3 is
recommended on Ascend. The recommended level is based on the expert experience, cannot
always generalize. User should specify the level for special network.
loss_scale_manager (Union[None, LossScaleManager]): If it is None, the loss would not be scaled.
Otherwise, scale the loss by LossScaleManager and optimizer can not be None. It is a key argument.
e.g. Use `loss_scale_manager=None` to set the value.
keep_batchnorm_fp32 (bool): Keep Batchnorm running in `float32`. If True, the level setting before
will be overwritten. Default: True.
Returns:
model (Object): High-Level API for Training.
`Model` groups layers into an object with training features.
Examples:
>>> from mindspore.nn.optim import thor
>>> from mindspore.train.model import Model
>>> from mindspore.train.loss_scale_manager import FixedLossScaleManager
>>>
>>> net = Net()
>>> loss_manager = FixedLossScaleManager(128, drop_overflow_update=False)
>>> opt = thor(net, lr, damping, momentum=0.9, weight_decay=1e-4, loss_scale=128, batch_size=32,
... frequency=100)
>>> model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_manager, metrics={"acc"},
... amp_level="O2", keep_batchnorm_fp32=False)
>>> model = ConvertModelUtils().convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,
... metrics={'acc'}, amp_level="O2",
... loss_scale_manager=loss_manager,
... keep_batchnorm_fp32=False)
""" """
api for convert model to thor model
"""
optim_name = type(optimizer).__name__ optim_name = type(optimizer).__name__
if optim_name in ("ThorAscend", "ThorGpu"): if optim_name in ("ThorAscend", "ThorGpu"):
from .model_thor import ModelThor from .model_thor import ModelThor

View File

@ -53,6 +53,7 @@ class DatasetHelper:
If sink_size=-1, sink the complete dataset for each epoch. If sink_size=-1, sink the complete dataset for each epoch.
If sink_size>0, sink sink_size data for each epoch. Default: -1. If sink_size>0, sink sink_size data for each epoch. Default: -1.
epoch_num (int): Control the number of epoch data to send. Default: 1. epoch_num (int): Control the number of epoch data to send. Default: 1.
iter_first_order (int): Control the number of steps first_order graph to execute. Default: 1.
Examples: Examples:
>>> dataset_helper = DatasetHelper(dataset) >>> dataset_helper = DatasetHelper(dataset)

View File

@ -26,7 +26,7 @@ from mindspore import context
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor
from mindspore.nn.wrap import DistributedGradReducer from mindspore.nn.wrap import DistributedGradReducer
from mindspore.train.train_thor.convert_utils import ConvertNetUntils from mindspore.train.train_thor.convert_utils import ConvertNetUtils
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
# Enumerates types of Layer # Enumerates types of Layer
@ -147,7 +147,7 @@ def get_layer_counter(layer_type, layer_counter, params, idx):
def THOR(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32, def THOR(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None): use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None):
context.set_context(max_call_depth=10000) context.set_context(max_call_depth=10000)
ConvertNetUntils().convert_to_thor_net(net) ConvertNetUtils().convert_to_thor_net(net)
return THOR_Ascend(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, decay_filter, return THOR_Ascend(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, decay_filter,
split_indices=split_indices) split_indices=split_indices)
@ -486,12 +486,9 @@ class THOR_Ascend(Optimizer):
temp_max = self.mul(temp_max, self.batch_size / A_normalizer) temp_max = self.mul(temp_max, self.batch_size / A_normalizer)
g = self.cube_matmul_left(temp_g, g) g = self.cube_matmul_left(temp_g, g)
g = self.cube_matmul_right_mul(g, temp_a, temp_max) g = self.cube_matmul_right_mul(g, temp_a, temp_max)
fake_A = self.assign(self.matrix_A[thor_layer_count], temp_a) self.assign(self.matrix_A[thor_layer_count], temp_a)
fake_G = self.assign(self.matrix_G[thor_layer_count], temp_g) self.assign(self.matrix_G[thor_layer_count], temp_g)
fake_max = self.assign(self.matrix_max_inv[thor_layer_count], temp_max) self.assign(self.matrix_max_inv[thor_layer_count], temp_max)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
g = F.depend(g, fake_max)
if i == 159: if i == 159:
new_grads = new_grads + (g, gradients[i + 1]) new_grads = new_grads + (g, gradients[i + 1])
else: else: