Insert load handling for different cases of sequence parameters.

Margaret_wangrui 2022-08-31 09:30:00 +08:00
parent 9818117c76
commit 79281813bd
7 changed files with 287 additions and 35 deletions

View File

@@ -254,6 +254,20 @@ class COMMON_EXPORT AnfAlgo {
return (abs != nullptr) && abs->isa<abstract::AbstractRefTensor>();
}
// Check whether the sequence node has Ref abstract.
static inline bool SequenceHasAbstractRef(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto &abs = node->abstract();
if ((abs != nullptr) && (abs->isa<abstract::AbstractSequence>())) {
auto abs_seq = abs->cast_ptr<abstract::AbstractSequence>();
AbstractBasePtrList elements = abs_seq->elements();
return std::any_of(elements.begin(), elements.end(), [](const AbstractBasePtr &element) {
return (element != nullptr) && element->isa<abstract::AbstractRefTensor>();
});
}
return false;
}
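// Illustrative: for a node whose abstract is Tuple(Ref[Tensor], Tensor), this returns true;
// for Tuple(Tensor, Tensor) it returns false.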
// Get the real output node and indexes of get item, make tuple, depend, load.
static AnfNodePtr GetTupleIndexes(const AnfNodePtr &node, std::vector<size_t> *const index_stack);
static bool IsNopNode(const AnfNodePtr &node);

View File

@@ -28,6 +28,7 @@
#include "ir/param_info.h"
#include "ir/cell.h"
#include "include/common/utils/python_adapter.h"
#include "include/common/utils/anfalgo.h"
#include "abstract/abstract_value.h"
#include "frontend/parallel/costmodel_context.h"
#include "include/common/utils/parallel_context.h"
@@ -1019,7 +1020,7 @@ bool ExistTarget(const std::vector<AnfNodePtr> &all_nodes, const std::string &ta
// If the return value of a subgraph is a Ref in control flow scenarios, graph mode should run kernel by kernel.
bool ExistSwitchRef(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &all_nodes) {
// %1 = switch(cond, func1, func2)
// %2 = %1() if the abstract of the node is AbstractRefTensor, return true.
// %2 = %1() if the abstract of the node is AbstractRefTensor or Tuple/List(AbstractRefTensor, ...), return true.
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
@@ -1032,8 +1033,7 @@ bool ExistSwitchRef(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr
auto &users = iter->second;
for (auto &user : users) {
auto &user_node = user.first;
const auto &abs = user_node->abstract();
if (abs != nullptr && abs->isa<abstract::AbstractRefTensor>()) {
if (common::AnfAlgo::HasAbstractRef(user_node) || common::AnfAlgo::SequenceHasAbstractRef(user_node)) {
return true;
}
}

View File

@@ -28,6 +28,7 @@
#include "frontend/operator/composite/multitype_funcgraph.h"
#include "utils/flags.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/anfalgo.h"
#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "utils/log_adapter.h"
@@ -103,22 +104,13 @@ int GetSideEffectPropagate(const PrimitivePtr &prim) {
return 0;
}
// Return true if the node has Ref abstract.
bool HasAbstractRef(const AnfNodePtr &node) {
if (node == nullptr) {
return false;
}
auto &abs = node->abstract();
return (abs != nullptr) && abs->isa<abstract::AbstractRefTensor>();
}
// Gets ref inputs and its indexes from a cnode.
RefInputs GetRefInputs(const CNodePtr &cnode) {
RefInputs ref_inputs;
MS_EXCEPTION_IF_NULL(cnode);
for (size_t i = 1; i < cnode->size(); ++i) {
auto &input = cnode->inputs().at(i);
if (HasAbstractRef(input)) {
if (common::AnfAlgo::HasAbstractRef(input)) {
ref_inputs[input].push_back(i);
}
}
@@ -132,14 +124,32 @@ bool HasRefInput(const CNodePtr &cnode) {
}
auto &inputs = cnode->inputs();
// Return true if any of the arguments is a Ref.
return std::any_of(inputs.begin() + 1, inputs.end(), [](const auto &input) { return HasAbstractRef(input); });
return std::any_of(inputs.begin() + 1, inputs.end(),
[](const auto &input) { return common::AnfAlgo::HasAbstractRef(input); });
}
// Return true if the cnode has a tuple(Ref) or list(Ref) input.
bool HasRefSequenceInput(const CNodePtr &cnode) {
if (cnode == nullptr || cnode->inputs().empty()) {
return false;
}
auto &inputs = cnode->inputs();
for (size_t index = 1; index < inputs.size(); ++index) {
auto input = cnode->input(index);
MS_EXCEPTION_IF_NULL(input);
if (common::AnfAlgo::SequenceHasAbstractRef(input)) {
return true;
}
}
return false;
}
// Return true if we don't need Load for the given primitive.
// i.e. keep Ref as Ref for some primitives.
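// MakeTuple/MakeList now keep Ref elements as Ref; Loads for those elements are inserted later
// at the consuming real node instead (see InsertLoadForSequenceRef).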
bool IsKeepRef(const PrimitivePtr &prim) {
return (GetSideEffectPropagate(prim) != 0) || IsPrimitiveEquals(prim, prim::kPrimRefToEmbed) ||
IsPrimitiveEquals(prim, prim::kPrimPull);
IsPrimitiveEquals(prim, prim::kPrimPull) || IsPrimitiveEquals(prim, prim::kPrimMakeTuple) ||
IsPrimitiveEquals(prim, prim::kPrimMakeList);
}
// Gets func_graph from the given cnode, return nullptr if it is not a func graph call.
@@ -177,6 +187,18 @@ prim::MultitypeFuncGraphPtr GetFuncMultitypeFuncGraph(const CNodePtr &cnode) {
return nullptr;
}
// Return true if the cnode is a non-effect real node whose inputs are dynamic (variable length).
bool IsNonEffectRealNodeAndInputIsDynamic(const CNodePtr &cnode) {
static const PrimitiveSet dynamic_input_node_prims = {
prim::kPrimStack, prim::kPrimConcat, prim::kPrimAddN, prim::kPrimIdentityN,
prim::kPrimSparseConcat, prim::kPrimMeshgrid, prim::kPrimDynamicStitch};
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim == nullptr) {
return false;
}
return dynamic_input_node_prims.find(prim) != dynamic_input_node_prims.end();
}
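// Illustrative: AddN consuming a tuple of parameters, e.g. %r = AddN(MakeTuple(%p1, %p2)),
// matches this predicate; effect nodes such as Assign do not.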
// --------------------------------------------------------------------
// SCC (Strongly Connected Components) related types.
// --------------------------------------------------------------------
@@ -383,11 +405,11 @@ class SideEffectFinder {
std::vector<FuncGraphPtr> branches;
auto true_branch = GetSwitchBranch(cnode, true_index);
if (true_branch != nullptr) {
branches.emplace_back(true_branch);
(void)branches.emplace_back(true_branch);
}
auto false_branch = GetSwitchBranch(cnode, false_index);
if (false_branch != nullptr) {
branches.emplace_back(false_branch);
(void)branches.emplace_back(false_branch);
}
if (branches.empty()) {
MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString();
@@ -494,8 +516,9 @@ class SideEffectFinder {
std::vector<FuncGraphPtr> GetSwitchLayerBranches(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
constexpr size_t func_tuple_index = 2;
constexpr int recursive_level = 2;
if (cnode->size() <= func_tuple_index) {
MS_LOG(EXCEPTION) << "Invalid switch_layer: " << cnode->DebugString(2);
MS_LOG(EXCEPTION) << "Invalid switch_layer: " << cnode->DebugString(recursive_level);
}
auto func_tuple = cnode->inputs().at(func_tuple_index);
return GetGraphsFromTuple(func_tuple);
@@ -528,15 +551,17 @@ class SideEffectFinder {
std::vector<FuncGraphPtr> GetGraphsFromMakeTuple(const CNodePtr &make_tuple) const {
MS_EXCEPTION_IF_NULL(make_tuple);
auto &inputs = make_tuple->inputs();
constexpr int recursive_level = 2;
if (inputs.size() <= 1) {
MS_LOG(EXCEPTION) << "Invalid make_tuple for switch_layer: " << make_tuple->DebugString(2);
MS_LOG(EXCEPTION) << "Invalid make_tuple for switch_layer: " << make_tuple->DebugString(recursive_level);
}
std::vector<FuncGraphPtr> graphs;
graphs.reserve(inputs.size() - 1);
for (size_t i = 1; i < inputs.size(); ++i) {
auto func_graph = GetValueNode<FuncGraphPtr>(inputs.at(i));
if (func_graph == nullptr) {
MS_LOG(WARNING) << "Non-graph found in switch_layer input: " << make_tuple->DebugString(2) << " index=" << i;
MS_LOG(WARNING) << "Non-graph found in switch_layer input: " << make_tuple->DebugString(recursive_level)
<< " index=" << i;
continue;
}
graphs.push_back(func_graph);
@@ -596,10 +621,11 @@ class SideEffectFinder {
MS_EXCEPTION_IF_NULL(tuple_indexes);
MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
constexpr int recursive_level = 2;
// Trace MakeTuple.
if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
if (tuple_indexes->empty()) {
MS_LOG(EXCEPTION) << "Unexpected make_tuple: " << cnode->DebugString(2);
MS_LOG(EXCEPTION) << "Unexpected make_tuple: " << cnode->DebugString(recursive_level);
}
// Pop out tuple index.
auto top_index = tuple_indexes->top();
@@ -647,11 +673,11 @@ class SideEffectFinder {
// %1 = J(primal)
// tuple = %1(args)
if (cnode->size() > 0 && IsPrimitiveCNode(cnode->input(0), prim::kPrimJ)) {
MS_LOG(DEBUG) << "Tuple from J: " << cnode->DebugString(2);
MS_LOG(DEBUG) << "Tuple from J: " << cnode->DebugString(recursive_level);
return {EffectInfo::kDetected, false, false, false};
}
// Rare case.
MS_LOG(WARNING) << "Tuple untraceable from: " << cnode->DebugString(2);
MS_LOG(WARNING) << "Tuple untraceable from: " << cnode->DebugString(recursive_level);
return {EffectInfo::kDetected, false, false, false};
}
@@ -981,6 +1007,9 @@ class SideEffectFinder {
// load is inserted inside the func_graph f.
info.load = HasRefInput(cnode);
}
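// For a non-effect real node with dynamic inputs (e.g. AddN over a parameter tuple), mark load
// if any sequence input carries a Ref, so that HandleLoad inserts Loads for its elements.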
if (!info.memory && IsNonEffectRealNodeAndInputIsDynamic(cnode)) {
info.load = HasRefSequenceInput(cnode);
}
return info;
}
@@ -1293,7 +1322,7 @@ class AutoMonadConverter {
// If the node has no side effects but the 'no_eliminate' flag is set,
// we save it to no_eliminate_nodes and handle them later.
if (!info.memory && !info.io && IsNoEliminateNode(cnode)) {
no_eliminate_nodes_.emplace_back(cnode);
(void)no_eliminate_nodes_.emplace_back(cnode);
}
}
cnode->SetEffectHandled(true);
@@ -1334,10 +1363,10 @@ class AutoMonadConverter {
AbstractBasePtrList element_abstracts;
tuple_inputs.reserve(no_eliminate_nodes_.size() + 1);
element_abstracts.reserve(no_eliminate_nodes_.size());
tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
(void)tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (auto &node : no_eliminate_nodes_) {
tuple_inputs.emplace_back(node);
element_abstracts.emplace_back(node->abstract());
(void)tuple_inputs.emplace_back(node);
(void)element_abstracts.emplace_back(node->abstract());
}
auto make_tuple_node = func_graph_->NewCNode(tuple_inputs);
make_tuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
@@ -1351,8 +1380,9 @@ class AutoMonadConverter {
// To: return output
void ClearIsolatedNodes() const {
auto output = GetGraphOutput();
constexpr size_t attach_index = 2;
if (IsPrimitiveCNode(output, prim::kPrimDepend) &&
IsPrimitiveCNode(output->cast<CNodePtr>()->input(2), prim::kPrimStopGradient)) {
IsPrimitiveCNode(output->cast<CNodePtr>()->input(attach_index), prim::kPrimStopGradient)) {
// Replace Depend(orig_output, StopGrad) node with orig_output.
// After that, nodes may be eliminated if they have no side effects.
auto &orig_output = output->cast<CNodePtr>()->input(1);
@@ -1415,6 +1445,10 @@ class AutoMonadConverter {
void HandleLoad(const CNodePtr &cnode, bool update_state) {
MS_EXCEPTION_IF_NULL(cnode);
// If the cnode is a real node with dynamic inputs, insert Loads for any sequence input that contains a Ref.
if (IsNonEffectRealNodeAndInputIsDynamic(cnode)) {
return InsertLoadForSequenceRef(cnode, update_state);
}
if (IsValueNode<Primitive>(cnode->input(0))) {
// For primitive calls that use Ref as input, insert Loads before them.
InsertLoads(cnode, update_state);
@@ -1426,6 +1460,75 @@
}
}
AnfNodePtr NewItemNode(const AnfNodePtr &node, const AbstractBasePtr &seq_abs, const AbstractBasePtr &item_abs,
size_t index) {
std::vector<AnfNodePtr> item_inputs;
if (seq_abs->isa<abstract::AbstractTuple>()) {
(void)item_inputs.emplace_back(NewValueNode(prim::kPrimTupleGetItem));
} else if (seq_abs->isa<abstract::AbstractList>()) {
(void)item_inputs.emplace_back(NewValueNode(prim::kPrimListGetItem));
}
(void)item_inputs.emplace_back(node);
(void)item_inputs.emplace_back(NewValueNode(SizeToLong(index)));
auto new_item = func_graph_->NewCNode(std::move(item_inputs));
new_item->set_abstract(item_abs);
if (item_abs->isa<abstract::AbstractRefTensor>()) {
// Current u monad.
auto current_u = GetUniverse();
// Make a Load for item node.
new_item = MakeLoad(node, new_item, current_u);
}
return new_item;
}
// For example: params = (param1, param2, ..., value)
// A non-effect node such as addn(params, xxx) needs Loads inserted for the Ref elements of params.
void InsertLoadForSequenceRef(const CNodePtr &cnode, bool update_state) {
const auto &inputs = cnode->inputs();
for (size_t index = 1; index < inputs.size(); ++index) {
const auto &input = inputs[index];
const auto &input_abs = input->abstract();
MS_EXCEPTION_IF_NULL(input_abs);
if (!input_abs->isa<abstract::AbstractTuple>() && !input_abs->isa<abstract::AbstractList>()) {
continue;
}
// Handle an input which is a sequence. Collect the new element abstracts per input,
// so that the abstract of the rebuilt sequence matches its own elements.
abstract::AbstractBasePtrList new_seq_abstracts;
std::vector<AnfNodePtr> new_sequence_inputs;
if (input_abs->isa<abstract::AbstractTuple>()) {
(void)new_sequence_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
} else if (input_abs->isa<abstract::AbstractList>()) {
(void)new_sequence_inputs.emplace_back(NewValueNode(prim::kPrimMakeList));
}
auto seq_abs = input_abs->cast_ptr<abstract::AbstractSequence>();
MS_EXCEPTION_IF_NULL(seq_abs);
const auto &elements = seq_abs->elements();
for (size_t item_index = 0; item_index < elements.size(); ++item_index) {
const auto &item_abs = elements[item_index];
auto item = NewItemNode(input, input_abs, item_abs, item_index);
(void)new_sequence_inputs.emplace_back(item);
(void)new_seq_abstracts.emplace_back(item->abstract());
}
auto new_seq = func_graph_->NewCNode(std::move(new_sequence_inputs));
MS_LOG(DEBUG) << "Replace the input of non-effect-node:" << cnode->DebugString()
<< " with:" << new_seq->DebugString();
if (input_abs->isa<abstract::AbstractTuple>()) {
new_seq->set_abstract(std::make_shared<abstract::AbstractTuple>(new_seq_abstracts));
} else if (input_abs->isa<abstract::AbstractList>()) {
new_seq->set_abstract(std::make_shared<abstract::AbstractList>(new_seq_abstracts));
}
manager_->SetEdge(cnode, index, new_seq);
if (update_state) {
auto current_u = GetUniverse();
// In the order_enforce phase, the cnode will be attached to an UpdateState to ensure ordering,
// and the number of inputs of that UpdateState is kept at 2 here,
// so that UpdateState verification in the related pass succeeds.
u_ = UpdateState(current_u, new_seq);
}
}
}
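// A sketch of the rewrite above (IR names illustrative):
//   before: %t = MakeTuple(%param1, %param2)   // Tuple(Ref[Tensor], Ref[Tensor])
//           %r = AddN(%t)
//   after:  %i0 = TupleGetItem(%t, 0)
//           %l0 = Load(%i0, %u)
//           %i1 = TupleGetItem(%t, 1)
//           %l1 = Load(%i1, %u)
//           %t2 = MakeTuple(%l0, %l1)
//           %r = AddN(%t2)
//           %u2 = UpdateState(%u, %t2)          // only when update_state is true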
//
// Insert Loads for a primitive cnode that uses Ref as input.
// For example, from:
@@ -1478,18 +1581,18 @@
for (size_t index : ref_input.second) {
manager_->SetEdge(cnode, index, load);
}
loads.emplace_back(std::move(load));
(void)loads.emplace_back(std::move(load));
}
return loads;
}
CNodePtr MakeLoad(const CNodePtr &cnode, const AnfNodePtr &ref, const AnfNodePtr &u) {
CNodePtr MakeLoad(const AnfNodePtr &node, const AnfNodePtr &ref, const AnfNodePtr &u) {
static const std::string primitive_target = "primitive_target";
// Create Load cnode.
auto load_prim = NewValueNode(prim::kPrimLoad);
auto load_cnode = func_graph_->NewCNode({load_prim, ref, u});
// Set device target for Load CNode.
std::string target = GetCNodeTarget(cnode);
std::string target = GetCNodeTarget(node);
load_cnode->set_user_data(primitive_target, std::make_shared<std::string>(target));
// Set load_cnode abstract to Tensor according to the input Ref[Tensor].
auto ref_abs = dyn_cast<abstract::AbstractRefTensor>(ref->abstract());

@@ -1 +1 @@
Subproject commit 5cd93e4d7af7e98e05fdcc98c901c33c971523c8
Subproject commit 763ef7f119301f465a75c7adc7b42ed60d3cfde2

View File

@@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -84,6 +84,7 @@ def test_auto_monad_addn_adam():
allclose_nparray(new_var_pyn.asnumpy(), new_var.asnumpy(), 0.001, 0.001)
allclose_nparray(new_m_pyn.asnumpy(), new_m.asnumpy(), 0.001, 0.001)
allclose_nparray(new_v_pyn.asnumpy(), new_v.asnumpy(), 0.001, 0.001)
context.set_context(mode=context.GRAPH_MODE)
class AutoMonadTwoAssignTwoAddnDependencyNet(Cell):
@@ -257,3 +258,133 @@ def test_parameter_tuple_assign():
out = net(x)
assert out[0] == 2
assert out[1] == 0
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_parameter_tuple_assign_addn():
"""
Feature: Auto monad feature.
Description: Parameter tuple assign and addn.
Expectation: No exception.
"""
class Net(Cell):
def __init__(self):
super().__init__()
self.assign = P.Assign()
self.addn = P.AddN()
self.param1 = Parameter(Tensor(1), name="param1")
self.param2 = Parameter(Tensor(2), name="param2")
def construct(self, x):
params = (self.param1, self.param2)
res1 = self.addn(params)
self.assign(params[0], x)
res2 = self.addn(params)
self.assign(params[1], x * 2)
res3 = self.addn(params)
res4 = params[0] + params[1]
res = (res1, res2, res3, res4)
return res
x = Tensor(3)
net = Net()
out = net(x)
assert out == (3, 5, 9, 9)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_parameter_tuple_assign_addn_inner_net():
"""
Feature: Auto monad feature.
Description: Parameter tuple assign and addn.
Expectation: No exception.
"""
class InnerNet(Cell):
def __init__(self):
super().__init__()
self.assign = P.Assign()
self.addn = P.AddN()
self.param1 = Parameter(Tensor(1), name="param1")
self.param2 = Parameter(Tensor(2), name="param2")
def construct(self, x):
params = (self.param1, self.param2)
res1 = self.addn(params)
self.assign(params[0], x)
res2 = self.addn(params)
res = (res1, res2, self.param1, self.param2)
return res
class Net(Cell):
def __init__(self):
super().__init__()
self.inner_net = InnerNet()
self.addn = P.AddN()
self.assign = P.Assign()
def construct(self, x, y):
inner_net_res = self.inner_net(x)
params = (inner_net_res[2], inner_net_res[3])
out_res1 = self.addn(params)
self.assign(inner_net_res[2], y)
out_res2 = self.addn(params)
self.assign(inner_net_res[3], 2 * y)
return out_res1 + out_res2, inner_net_res[2] + inner_net_res[3]
input_x = Tensor(3)
input_y = Tensor(5)
net = Net()
out = net(input_x, input_y)
assert out == (12, 15)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_parameter_tuple_assign_addn_inner_net_control_flow():
"""
Feature: Auto monad feature.
Description: Parameter tuple assign and addn.
Expectation: No exception.
"""
class InnerNet(Cell):
def __init__(self):
super().__init__()
self.param1 = Parameter(Tensor(1), name="param1")
self.param2 = Parameter(Tensor(2), name="param2")
def construct(self, x):
if x > 0:
return self.param1, self.param2
return self.param2, self.param1
class Net(Cell):
def __init__(self):
super().__init__()
self.inner_net = InnerNet()
self.addn = P.AddN()
self.assign = P.Assign()
def construct(self, x, y):
inner_params = self.inner_net(x)
out_res1 = self.addn(inner_params)
self.assign(inner_params[1], y)
out_res2 = self.addn(inner_params)
self.assign(inner_params[0], 2 * y)
return out_res1 + out_res2, inner_params[0] + inner_params[1]
input_x = Tensor(3)
input_y = Tensor(5)
net = Net()
out = net(input_x, input_y)
assert out == (9, 15)

View File

@@ -22,6 +22,7 @@ from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode
import mindspore.common._monad as monad
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.ops import composite as C
from mindspore.ops import functional as F
@@ -354,6 +355,7 @@ class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
self.load = P.Load()
def construct(self,
input_ids,
@@ -399,7 +401,7 @@ class BertTrainOneStepWithLossScaleCell(nn.TrainOneStepWithLossScaleCell):
overflow = self.loss_scaling_manager(self.loss_scale, cond)
if not overflow:
self.optimizer(grads)
return (loss, cond, scaling_sens)
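# Load(scaling_sens, U) reads scaling_sens under the U monad, presumably so that the
# returned value is ordered after the loss-scale/optimizer updates above.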
return loss, cond, self.load(scaling_sens, monad.U)
class BertTrainOneStepWithLossScaleCellForAdam(nn.TrainOneStepWithLossScaleCell):

View File

@@ -25,6 +25,7 @@ from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.context import ParallelMode
import mindspore.common._monad as monad
from mindspore.communication.management import get_group_size
from mindspore import context
from .bert_model import BertModel
@@ -377,6 +378,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
self.load = P.Load()
def construct(self,
input_ids,
@@ -433,4 +435,4 @@
overflow = self.loss_scaling_manager(self.loss_scale, cond)
if not overflow:
self.optimizer(grads)
return (loss, cond, scaling_sens)
return loss, cond, self.load(scaling_sens, monad.U)