forked from mindspore-Ecosystem/mindspore
fix switch_layer_issues
This commit is contained in:
parent
e711aecdc2
commit
f9f3cd7ce0
|
@ -462,7 +462,8 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
|
|||
std::vector<FuncGraphPtr> graphs{};
|
||||
auto graphs_cnode = sw->input(2)->cast<CNodePtr>();
|
||||
auto &graphs_inputs = graphs_cnode->inputs();
|
||||
if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(graphs_inputs[1])) {
|
||||
if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && graphs_inputs.size() >= 2 &&
|
||||
IsValueNode<FuncGraph>(graphs_inputs[1])) {
|
||||
(void)std::transform(graphs_inputs.begin() + 1, graphs_inputs.end(), std::back_inserter(graphs),
|
||||
[](const AnfNodePtr &vnode) { return GetValueNode<FuncGraphPtr>(vnode); });
|
||||
}
|
||||
|
|
|
@ -89,6 +89,7 @@ class GetItemTransformACrossGraph {
|
|||
ss << idx;
|
||||
|
||||
auto new_fg_outer = TransformableClone(fg, std::make_shared<TraceTransform>(ss.str()));
|
||||
fg->manager()->AddFuncGraph(new_fg_outer);
|
||||
auto output_outer = new_fg_outer->output();
|
||||
if (!IsValueNode<FuncGraph>(output_outer)) {
|
||||
MS_LOG(WARNING) << "Output of outer graph should be a func_graph";
|
||||
|
@ -486,7 +487,7 @@ class IncorporateGetitemSwitchLayerA : public AnfVisitor {
|
|||
switch_layer_ = inputs[0];
|
||||
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_));
|
||||
}
|
||||
if (is_in_switch_ && cnode->size() > 2) {
|
||||
if (is_in_switch_ && cnode->size() >= 2) {
|
||||
auto &inputs = cnode->inputs();
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(inputs[1])) {
|
||||
(void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_),
|
||||
|
@ -578,7 +579,7 @@ class IncorporateGetitemSwitchLayerB : public AnfVisitor {
|
|||
switch_layer_call_ = inputs[0];
|
||||
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(outer_call_args_));
|
||||
}
|
||||
if (is_in_switch_ && cnode->size() > 2) {
|
||||
if (is_in_switch_ && cnode->size() >= 2) {
|
||||
auto &inputs = cnode->inputs();
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(inputs[1])) {
|
||||
(void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_),
|
||||
|
|
|
@ -36,10 +36,9 @@ class SwitchLayerDeferInline : public AnfVisitor {
|
|||
auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->inputs()[2]->abstract());
|
||||
for (auto elem : tuple->elements()) {
|
||||
auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem);
|
||||
if (abstract == nullptr) {
|
||||
return nullptr;
|
||||
if (abstract != nullptr) {
|
||||
*(abstract->func_graph()->switch_layer_input()) = true;
|
||||
}
|
||||
*(abstract->func_graph()->switch_layer_input()) = true;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -137,6 +137,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.arithmetic_simplify2_,
|
||||
irpass.same_eliminate_,
|
||||
irpass.check_bprop_eliminate_,
|
||||
irpass.switch_layer_defer_inline_,
|
||||
irpass.replace_applicator_,
|
||||
});
|
||||
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/infer_functions.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "utils/symbolic.h"
|
||||
|
||||
|
@ -121,12 +122,18 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
|
|||
for (size_t i = 0; i < branches.size(); i++) {
|
||||
MS_EXCEPTION_IF_NULL(branches[i]);
|
||||
if (!branches[i]->isa<AbstractFunction>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " requires that the 2th arg be tuple of functions, but got "
|
||||
<< branches[i]->ToString() << " as the " << i << "th element.";
|
||||
MS_EXCEPTION(ValueError) << op_name << " requires that the 2th arg be tuple of functions, but got "
|
||||
<< branches[i]->ToString() << " as the " << i << "th element.";
|
||||
}
|
||||
}
|
||||
|
||||
auto b = branches[0];
|
||||
// Return AbstractFuncUnion, otherwise the switch_layer will be replaced by branches[0]
|
||||
// which will cancel the out of bound checking for index
|
||||
if (branches.size() == 1) {
|
||||
AbstractFuncAtomPtrList func_list{b->cast<AbstractFuncAtomPtr>()};
|
||||
return std::make_shared<AbstractFuncUnion>(func_list);
|
||||
}
|
||||
for (size_t i = 1; i < branches.size(); i++) {
|
||||
b = b->Join(branches[i]);
|
||||
}
|
||||
|
|
|
@ -444,6 +444,86 @@ def test_index_to_switch_layer():
|
|||
C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
||||
|
||||
|
||||
def test_parser_switch_layer_switch_in_bprop():
|
||||
class OneInputBprop(nn.Cell):
|
||||
def __init__(self, funcs):
|
||||
super(OneInputBprop, self).__init__()
|
||||
self.op = P.ReLU()
|
||||
self.funcs = funcs
|
||||
def construct(self, i, x):
|
||||
return self.op(x)
|
||||
def bprop(self, i, x, out, dout):
|
||||
return i, self.funcs[i](x, dout)
|
||||
|
||||
class Add(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.op = P.TensorAdd()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.op(x, y)
|
||||
|
||||
class Mul(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.op = P.Mul()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.op(x, y)
|
||||
func1 = Add()
|
||||
func2 = Mul()
|
||||
funcs = (func1, func2)
|
||||
net = OneInputBprop(funcs)
|
||||
input1 = Tensor(np.ones([2, 2]).astype(np.float32))
|
||||
grad = Tensor(np.random.randn(2, 2).astype(np.float32))
|
||||
i = Tensor(1, mstype.int32)
|
||||
grad_net = C.grad_all_with_sens(net)
|
||||
grad_net(i, input1, grad)
|
||||
|
||||
|
||||
def test_parser_switch_layer_inputs_tuple():
|
||||
class TwoInputTupleFinalNet(nn.Cell):
|
||||
def __init__(self, funcs):
|
||||
super().__init__()
|
||||
self.funcs = funcs
|
||||
|
||||
def construct(self, i, inputa, inputb):
|
||||
inputs = (inputa, inputb)
|
||||
x = self.funcs[i](inputs)
|
||||
return x
|
||||
|
||||
class Add(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.op = P.TensorAdd()
|
||||
|
||||
def construct(self, x):
|
||||
y = self.op(x[0], x[1])
|
||||
return self.op(x[0], y)
|
||||
|
||||
class Mul(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.op = P.Mul()
|
||||
|
||||
def construct(self, x):
|
||||
y = self.op(x[0], x[1])
|
||||
return self.op(x[0], y)
|
||||
|
||||
func1 = Add()
|
||||
func2 = Mul()
|
||||
|
||||
funcs = (func1, func2)
|
||||
net = TwoInputTupleFinalNet(funcs)
|
||||
|
||||
input1 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
|
||||
input2 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
|
||||
i = Tensor(1, mstype.int32)
|
||||
grad = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
|
||||
back_net = C.grad_all_with_sens(net)
|
||||
back_out = back_net(i, input1, input2, grad)
|
||||
|
||||
|
||||
def test_switch_layer_with_single_prim():
|
||||
class SwitchLayerCell(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -494,6 +574,35 @@ def test_switch_layer_env_eliminate():
|
|||
net2(x, i)
|
||||
|
||||
|
||||
def test_switch_layer_single_layer():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.conv = nn.Conv2d(1, 1, 3, pad_mode='same')
|
||||
self.funs = (self.conv,)
|
||||
|
||||
def construct(self, x, index):
|
||||
x = self.funs[index](x)
|
||||
return x
|
||||
|
||||
class NetGrad(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(NetGrad, self).__init__()
|
||||
self.grad_op = C.GradOperation('grad', get_by_list=True, sens_param=False)
|
||||
self.net = net
|
||||
self.weights = ParameterTuple(self.net.trainable_params())
|
||||
|
||||
def construct(self, x, index):
|
||||
weights = self.weights
|
||||
grad = self.grad_op(self.net, weights)(x, index)
|
||||
return grad
|
||||
net = Net()
|
||||
net2 = NetGrad(net)
|
||||
x = Tensor(np.ones((3, 1, 12, 12)), ms.float32)
|
||||
i = Tensor(1, ms.int32)
|
||||
net2(x, i)
|
||||
|
||||
|
||||
def test_control_depend_check():
|
||||
with pytest.raises(TypeError) as e:
|
||||
P.ControlDepend(0.0)
|
||||
|
|
Loading…
Reference in New Issue