forked from mindspore-Ecosystem/mindspore
!18126 add comments for master thor api
Merge pull request !18126 from wangshuangling/master
This commit is contained in:
commit
22181af515
|
@ -552,7 +552,7 @@ class EmbeddingThor(Cell):
|
|||
Tensor of output shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`.
|
||||
|
||||
Examples:
|
||||
>>> net = nn.Embedding(20000, 768, True)
|
||||
>>> net = nn.EmbeddingThor(20000, 768, True)
|
||||
>>> input_data = Tensor(np.ones([8, 128]), mindspore.int32)
|
||||
>>>
|
||||
>>> # Maps the input word IDs to word embedding.
|
||||
|
|
|
@ -27,7 +27,7 @@ from mindspore import context
|
|||
from mindspore.context import ParallelMode
|
||||
from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor
|
||||
from mindspore.nn.wrap import DistributedGradReducer
|
||||
from mindspore.train.train_thor.convert_utils import ConvertNetUntils
|
||||
from mindspore.train.train_thor.convert_utils import ConvertNetUtils
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
|
||||
# Enumerates types of Layer
|
||||
|
@ -101,6 +101,17 @@ def clip_gradient(enable_clip_grad, gradients):
|
|||
C0 = 16
|
||||
|
||||
|
||||
def _check_param(momentum, frequency, lr, cls_name):
|
||||
"""Check param."""
|
||||
Validator.check_value_type("momentum", momentum, [float], cls_name)
|
||||
if isinstance(momentum, float) and momentum < 0.0:
|
||||
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
|
||||
Validator.check_value_type("frequency", frequency, [int], cls_name)
|
||||
if isinstance(frequency, int) and frequency < 2:
|
||||
raise ValueError("frequency should be at least 2, but got frequency {}".format(frequency))
|
||||
Validator.check_value_type("learning rate", lr, [Tensor], cls_name)
|
||||
|
||||
|
||||
def caculate_device_shape(matrix_dim, channel, is_a):
|
||||
ll = (0)
|
||||
if is_a:
|
||||
|
@ -188,13 +199,103 @@ def thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0
|
|||
use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False,
|
||||
frequency=100):
|
||||
r"""
|
||||
Updates gradients by the THOR algorithm.
|
||||
Updates gradients by second-order algorithm--THOR.
|
||||
|
||||
Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation (THOR) algorithm is proposed in:
|
||||
|
||||
`THOR: Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation
|
||||
<https://www.aaai.org/AAAI21Papers/AAAI-6611.ChenM.pdf>`_
|
||||
|
||||
The updating formulas are as follows,
|
||||
|
||||
.. math::
|
||||
\begin{array}{ll} \\
|
||||
A_i = a_i{a_i}^T \\
|
||||
G_i = D_{s_i}{ D_{s_i}}^T \\
|
||||
m_i = \beta * m_i + ({G_i^{(k)}}+\lambda I)^{-1}) g_i ({\overline A_{i-1}^{(k)}}+\lambda I)^{-1} \\
|
||||
w_i = w_i - \alpha * m_i \\
|
||||
\end{array}
|
||||
|
||||
:math:`D_{s_i}` represents the derivative of the loss function of the output of the i-th layer,
|
||||
:math:`a_{i-1}` represents the input of i-th layer,and which is the activations of previous layer,
|
||||
:math:`\beta` represents momentum, :math:`I` represents the identity matrix,
|
||||
:math:`\overline A` represents the transpose of matrix A,
|
||||
:math:`\lambda` represents 'damping', :math:`g_i` represents gradients of the i-th layer,
|
||||
:math:`\otimes` represents Kronecker product, :math:`\alpha` represents 'learning rate'
|
||||
|
||||
Note:
|
||||
When separating parameter groups, the weight decay in each group will be applied on the parameters if the
|
||||
weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
|
||||
on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
|
||||
|
||||
When separating parameter groups, if you want to centralize the gradient, set grad_centralization to True,
|
||||
but the gradient centralization can only be applied to the parameters of the convolution layer.
|
||||
If the parameters of the non convolution layer are set to True, an error will be reported.
|
||||
|
||||
To improve parameter groups performance, the customized order of parameters can be supported.
|
||||
|
||||
Args:
|
||||
net (Cell): The training network.
|
||||
|
||||
learning_rate (Tensor): A value for the learning rate.
|
||||
|
||||
damping (Tensor): A value for the damping.
|
||||
|
||||
momentum (float): Hyper-parameter of type float, means momentum for the moving average. It must be at least 0.0.
|
||||
|
||||
weight_decay (int, float): Weight decay (L2 penalty). It must be equal to or greater than 0.0. Default: 0.0.
|
||||
|
||||
loss_scale (float): A value for the loss scale. It must be greater than 0.0. In general, use the
|
||||
default value. Default: 1.0.
|
||||
|
||||
batch_size (int): The size of a batch. Default: 32
|
||||
|
||||
use_nesterov (bool): Enable Nesterov momentum. Default: False.
|
||||
|
||||
decay_filter (function): A function to determine which layers the weight decay applied to. And it
|
||||
only works when the weight_decay > 0. Default: lambda x: x.name not in []
|
||||
|
||||
split_indices (list): Set allreduce fusion strategy by A/G layer indices . Only works when distributed
|
||||
computing. ResNet50 as an example, there are 54 layers of A/G respectively, when split_indices is set
|
||||
to [26, 53], it means A/G is divided into two groups to allreduce, one is 0~26 layer, and the other
|
||||
is 27~53. Default: None
|
||||
|
||||
enable_clip_grad (bool): Whether to clip the gradients. Default: False
|
||||
|
||||
frequency(int): The update interval of A/G and $A^{-1}/G^{-1}$. When frequency equals N (N is greater than 1),
|
||||
A/G and $A^{-1}/G^{-1}$ will be updated every N steps, and other steps will use the stale A/G and
|
||||
$A^{-1}/G^{-1}$ to update weights. Default: 100.
|
||||
|
||||
Inputs:
|
||||
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
|
||||
|
||||
Outputs:
|
||||
tuple[bool], all elements are True.
|
||||
|
||||
Raises:
|
||||
TypeError: If `learning_rate` is not Tensor.
|
||||
TypeError: If `loss_scale`,`momentum` or `frequency` is not a float.
|
||||
TypeError: If `weight_decay` is neither float nor int.
|
||||
TypeError: If `use_nesterov` is not a bool.
|
||||
ValueError: If `loss_scale` is less than or equal to 0.
|
||||
ValueError: If `weight_decay` or `momentum` is less than 0.
|
||||
ValueError: If `frequency` is not int.
|
||||
ValueError: If `frequency` is less than 2.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU``
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> optim = thor(net, lr=Tensor(1e-3), damping=Tensor(1e-3), momentum=0.9)
|
||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
>>> model = ConvertModelUtils().convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,
|
||||
... loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O2", keep_batchnorm_fp32=False)
|
||||
>>> model.train(config.epoch_size, dataset, callbacks=cb, sink_size=100, dataset_sink_mode=True)
|
||||
|
||||
"""
|
||||
context.set_context(max_call_depth=10000)
|
||||
ConvertNetUntils().convert_to_thor_net(net)
|
||||
ConvertNetUtils().convert_to_thor_net(net)
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
return ThorAscend(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, decay_filter,
|
||||
split_indices=split_indices, enable_clip_grad=enable_clip_grad, frequency=frequency)
|
||||
|
@ -212,9 +313,7 @@ class ThorGpu(Optimizer):
|
|||
enable_clip_grad=False, frequency=100):
|
||||
params = filter(lambda x: x.requires_grad, net.get_parameters())
|
||||
super(ThorGpu, self).__init__(learning_rate, params, weight_decay, loss_scale)
|
||||
Validator.check_value_type("momentum", momentum, [float], self.cls_name)
|
||||
if isinstance(momentum, float) and momentum < 0.0:
|
||||
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
|
||||
_check_param(momentum, frequency, learning_rate, self.__class__.__name__)
|
||||
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
|
||||
self.params = self.parameters
|
||||
self.use_nesterov = Validator.check_bool(use_nesterov)
|
||||
|
@ -448,18 +547,14 @@ class ThorGpu(Optimizer):
|
|||
matrix_a = matrix_a_allreduce[thor_layer_count]
|
||||
matrix_g = matrix_g_allreduce[thor_layer_count]
|
||||
g = self.update_gradient(matrix_g, g, matrix_a)
|
||||
fake_a = self.assign(self.matrix_a[thor_layer_count], matrix_a)
|
||||
fake_g = self.assign(self.matrix_g[thor_layer_count], matrix_g)
|
||||
g = F.depend(g, fake_a)
|
||||
g = F.depend(g, fake_g)
|
||||
self.assign(self.matrix_a[thor_layer_count], matrix_a)
|
||||
self.assign(self.matrix_g[thor_layer_count], matrix_g)
|
||||
g = self._reshape_gradient(conv_layer_count, g, g_shape)
|
||||
elif layer_type == Embedding:
|
||||
matrix_a = matrix_a_allreduce[thor_layer_count]
|
||||
matrix_g = matrix_g_allreduce[thor_layer_count]
|
||||
fake_a = self.assign(self.matrix_a[thor_layer_count], matrix_a)
|
||||
fake_g = self.assign(self.matrix_g[thor_layer_count], matrix_g)
|
||||
g = F.depend(g, fake_a)
|
||||
g = F.depend(g, fake_g)
|
||||
self.assign(self.matrix_a[thor_layer_count], matrix_a)
|
||||
self.assign(self.matrix_g[thor_layer_count], matrix_g)
|
||||
temp_a = self.expand(matrix_a, 1)
|
||||
g = self.mul(temp_a, g)
|
||||
g = self.matmul(g, matrix_g)
|
||||
|
@ -507,8 +602,7 @@ class ThorAscend(Optimizer):
|
|||
decay_filter=lambda x: x.name not in [], split_indices=None, enable_clip_grad=False, frequency=100):
|
||||
params = filter(lambda x: x.requires_grad, net.get_parameters())
|
||||
super(ThorAscend, self).__init__(learning_rate, params, weight_decay, loss_scale)
|
||||
if isinstance(momentum, float) and momentum < 0.0:
|
||||
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
|
||||
_check_param(momentum, frequency, learning_rate, self.__class__.__name__)
|
||||
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
|
||||
self.params = self.parameters
|
||||
self.moments = self.params.clone(prefix="moments", init='zeros')
|
||||
|
@ -845,10 +939,8 @@ class ThorAscend(Optimizer):
|
|||
"""process thor graph fc layer"""
|
||||
temp_a = matrix_a_allreduce[thor_layer_count]
|
||||
temp_g = matrix_g_allreduce[thor_layer_count]
|
||||
fake_a = self.assign(self.matrix_a_cov[thor_layer_count], temp_a)
|
||||
fake_g = self.assign(self.matrix_g_cov[thor_layer_count], temp_g)
|
||||
g = F.depend(g, fake_a)
|
||||
g = F.depend(g, fake_g)
|
||||
self.assign(self.matrix_a_cov[thor_layer_count], temp_a)
|
||||
self.assign(self.matrix_g_cov[thor_layer_count], temp_g)
|
||||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g = self.cast(g, mstype.float16)
|
||||
|
@ -925,10 +1017,8 @@ class ThorAscend(Optimizer):
|
|||
if layer_type == Embedding:
|
||||
temp_a_ori = matrix_a_allreduce[thor_layer_count]
|
||||
temp_g = matrix_g_allreduce[thor_layer_count]
|
||||
fake_a = self.assign(self.matrix_a_cov[thor_layer_count], temp_a_ori)
|
||||
fake_g = self.assign(self.matrix_g_cov[thor_layer_count], temp_g)
|
||||
g = F.depend(g, fake_a)
|
||||
g = F.depend(g, fake_g)
|
||||
self.assign(self.matrix_a_cov[thor_layer_count], temp_a_ori)
|
||||
self.assign(self.matrix_g_cov[thor_layer_count], temp_g)
|
||||
temp_a = self.expand(temp_a_ori, 1)
|
||||
g = self.mul(temp_a, g)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
|
@ -982,12 +1072,9 @@ class ThorAscend(Optimizer):
|
|||
temp_a = self.cast(temp_a, mstype.float16)
|
||||
temp_g = self.cast(temp_g, mstype.float16)
|
||||
g, temp_max = self._get_second_grad_by_matmul(i, temp_a, temp_g, g, temp_max)
|
||||
fake_a = self.assign(self.matrix_a[thor_layer_count], temp_a)
|
||||
fake_g = self.assign(self.matrix_g[thor_layer_count], temp_g)
|
||||
fake_max = self.assign(self.matrix_max_inv[thor_layer_count], temp_max)
|
||||
g = F.depend(g, fake_a)
|
||||
g = F.depend(g, fake_g)
|
||||
g = F.depend(g, fake_max)
|
||||
self.assign(self.matrix_a[thor_layer_count], temp_a)
|
||||
self.assign(self.matrix_g[thor_layer_count], temp_g)
|
||||
self.assign(self.matrix_max_inv[thor_layer_count], temp_max)
|
||||
new_grads = new_grads + (g,)
|
||||
gradients = new_grads
|
||||
else:
|
||||
|
|
|
@ -85,7 +85,7 @@ class CusCholeskyTrsm(PrimitiveWithInfer):
|
|||
Examples:
|
||||
>>> input_x = Tensor(np.ones(shape=[256, 256]), mindspore.float32)
|
||||
>>> cus_choleskytrsm = ops.CusCholeskyTrsm()
|
||||
>>> output = matmul(input_x)
|
||||
>>> output = cus_choleskytrsm(input_x)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
|
|
@ -14,6 +14,6 @@
|
|||
# ============================================================================
|
||||
"""convert to second order related classes and functions."""
|
||||
|
||||
from .convert_utils import ConvertNetUntils, ConvertModelUtils
|
||||
from .convert_utils import ConvertNetUtils, ConvertModelUtils
|
||||
|
||||
__all__ = ["ConvertNetUntils", "ConvertModelUtils"]
|
||||
__all__ = ["ConvertNetUtils", "ConvertModelUtils"]
|
||||
|
|
|
@ -13,29 +13,28 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
convert utils for second order optimizer: thor
|
||||
Conversion interface for second-order optimizer thor
|
||||
"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class ConvertNetUntils():
|
||||
class ConvertNetUtils():
|
||||
"""
|
||||
Convert net to thor layer net
|
||||
"""
|
||||
def __init__(self):
|
||||
self._convert_method_map = {nn.Dense: ConvertNetUntils._convert_dense,
|
||||
nn.Embedding: ConvertNetUntils._convert_embedding,
|
||||
nn.Conv2d: ConvertNetUntils._convert_conv2d}
|
||||
self._convert_method_map = {nn.Dense: ConvertNetUtils._convert_dense,
|
||||
nn.Embedding: ConvertNetUtils._convert_embedding,
|
||||
nn.Conv2d: ConvertNetUtils._convert_conv2d}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _convert_dense(subcell):
|
||||
"""
|
||||
convert dense cell to second_order cell
|
||||
Convert dense cell to second-order cell
|
||||
"""
|
||||
|
||||
weight = subcell.weight
|
||||
act_name = None
|
||||
if subcell.activation_flag:
|
||||
|
@ -69,7 +68,7 @@ class ConvertNetUntils():
|
|||
@staticmethod
|
||||
def _convert_embedding(subcell):
|
||||
"""
|
||||
convert embedding cell to second_order cell
|
||||
Convert embedding cell to second-order cell
|
||||
"""
|
||||
new_subcell = nn.EmbeddingThor(vocab_size=subcell.vocab_size,
|
||||
embedding_size=subcell.embedding_size,
|
||||
|
@ -81,7 +80,7 @@ class ConvertNetUntils():
|
|||
@staticmethod
|
||||
def _convert_conv2d(subcell):
|
||||
"""
|
||||
convert conv2d cell to second_order cell
|
||||
Convert conv2d cell to second-order cell
|
||||
"""
|
||||
out_channel = subcell.out_channels
|
||||
in_channel = subcell.in_channels
|
||||
|
@ -99,7 +98,7 @@ class ConvertNetUntils():
|
|||
|
||||
def _convert_to_thor_net(self, net):
|
||||
"""
|
||||
convert net to thor net
|
||||
Convert net to thor net
|
||||
"""
|
||||
cells = net.name_cells()
|
||||
change = False
|
||||
|
@ -131,25 +130,76 @@ class ConvertNetUntils():
|
|||
|
||||
def convert_to_thor_net(self, net):
|
||||
"""
|
||||
api for convert net to thor net
|
||||
This interface is used to convert a network to thor layer network, in order to calculate and store the
|
||||
second-order information matrix.
|
||||
|
||||
Notes:
|
||||
This interface is automatically called by the second-order optimizer thor.
|
||||
|
||||
Args:
|
||||
net (Cell): network to be trained by the second-order optimizer thor.
|
||||
|
||||
Examples:
|
||||
>>> ConvertNetUtils().convert_to_thor_net(net)
|
||||
"""
|
||||
|
||||
net.update_cell_prefix()
|
||||
self._convert_to_thor_net(net)
|
||||
net.update_cell_type("second_order")
|
||||
net.update_cell_type("second-order")
|
||||
|
||||
|
||||
class ConvertModelUtils():
|
||||
"""
|
||||
convert model to thor model utils
|
||||
Convert model to thor model.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def convert_to_thor_model(model, network, loss_fn=None, optimizer=None, metrics=None, amp_level="O0",
|
||||
loss_scale_manager=None, keep_batchnorm_fp32=False):
|
||||
"""
|
||||
This interface is used to convert model to thor model.
|
||||
|
||||
Args:
|
||||
model (Object): High-Level API for Training.
|
||||
`Model` groups layers into an object with training features.
|
||||
network (Cell): A training network.
|
||||
loss_fn (Cell): Objective function. Default: None.
|
||||
optimizer (Cell): Optimizer used to updating the weights. Default: None.
|
||||
metrics (Union[dict, set]): A Dictionary or a set of metrics to be evaluated by the model during
|
||||
training. eg: {'accuracy', 'recall'}. Default: None.
|
||||
amp_level (str): Level for mixed precision training. Supports ["O0", "O2", "O3", "auto"]. Default: "O0".
|
||||
- O0: Do not change.
|
||||
- O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
|
||||
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
|
||||
- auto: Set level to recommended level in different devices. O2 is recommended on GPU, O3 is
|
||||
recommended on Ascend. The recommended level is based on the expert experience, cannot
|
||||
always generalize. User should specify the level for special network.
|
||||
loss_scale_manager (Union[None, LossScaleManager]): If it is None, the loss would not be scaled.
|
||||
Otherwise, scale the loss by LossScaleManager and optimizer can not be None. It is a key argument.
|
||||
e.g. Use `loss_scale_manager=None` to set the value.
|
||||
keep_batchnorm_fp32 (bool): Keep Batchnorm running in `float32`. If True, the level setting before
|
||||
will be overwritten. Default: True.
|
||||
|
||||
Returns:
|
||||
model (Object): High-Level API for Training.
|
||||
`Model` groups layers into an object with training features.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.nn.optim import thor
|
||||
>>> from mindspore.train.model import Model
|
||||
>>> from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
>>>
|
||||
>>> net = Net()
|
||||
>>> loss_manager = FixedLossScaleManager(128, drop_overflow_update=False)
|
||||
>>> opt = thor(net, lr, damping, momentum=0.9, weight_decay=1e-4, loss_scale=128, batch_size=32,
|
||||
... frequency=100)
|
||||
>>> model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_manager, metrics={"acc"},
|
||||
... amp_level="O2", keep_batchnorm_fp32=False)
|
||||
>>> model = ConvertModelUtils().convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,
|
||||
... metrics={'acc'}, amp_level="O2",
|
||||
... loss_scale_manager=loss_manager,
|
||||
... keep_batchnorm_fp32=False)
|
||||
"""
|
||||
api for convert model to thor model
|
||||
"""
|
||||
|
||||
optim_name = type(optimizer).__name__
|
||||
if optim_name in ("ThorAscend", "ThorGpu"):
|
||||
from .model_thor import ModelThor
|
||||
|
|
|
@ -53,6 +53,7 @@ class DatasetHelper:
|
|||
If sink_size=-1, sink the complete dataset for each epoch.
|
||||
If sink_size>0, sink sink_size data for each epoch. Default: -1.
|
||||
epoch_num (int): Control the number of epoch data to send. Default: 1.
|
||||
iter_first_order (int): Control the number of steps first_order graph to execute. Default: 1.
|
||||
|
||||
Examples:
|
||||
>>> dataset_helper = DatasetHelper(dataset)
|
||||
|
|
|
@ -26,7 +26,7 @@ from mindspore import context
|
|||
from mindspore.context import ParallelMode
|
||||
from mindspore.nn.layer import DenseThor, Conv2dThor, EmbeddingThor
|
||||
from mindspore.nn.wrap import DistributedGradReducer
|
||||
from mindspore.train.train_thor.convert_utils import ConvertNetUntils
|
||||
from mindspore.train.train_thor.convert_utils import ConvertNetUtils
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
|
||||
# Enumerates types of Layer
|
||||
|
@ -147,7 +147,7 @@ def get_layer_counter(layer_type, layer_counter, params, idx):
|
|||
def THOR(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
|
||||
use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None):
|
||||
context.set_context(max_call_depth=10000)
|
||||
ConvertNetUntils().convert_to_thor_net(net)
|
||||
ConvertNetUtils().convert_to_thor_net(net)
|
||||
|
||||
return THOR_Ascend(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, decay_filter,
|
||||
split_indices=split_indices)
|
||||
|
@ -486,12 +486,9 @@ class THOR_Ascend(Optimizer):
|
|||
temp_max = self.mul(temp_max, self.batch_size / A_normalizer)
|
||||
g = self.cube_matmul_left(temp_g, g)
|
||||
g = self.cube_matmul_right_mul(g, temp_a, temp_max)
|
||||
fake_A = self.assign(self.matrix_A[thor_layer_count], temp_a)
|
||||
fake_G = self.assign(self.matrix_G[thor_layer_count], temp_g)
|
||||
fake_max = self.assign(self.matrix_max_inv[thor_layer_count], temp_max)
|
||||
g = F.depend(g, fake_A)
|
||||
g = F.depend(g, fake_G)
|
||||
g = F.depend(g, fake_max)
|
||||
self.assign(self.matrix_A[thor_layer_count], temp_a)
|
||||
self.assign(self.matrix_G[thor_layer_count], temp_g)
|
||||
self.assign(self.matrix_max_inv[thor_layer_count], temp_max)
|
||||
if i == 159:
|
||||
new_grads = new_grads + (g, gradients[i + 1])
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue