fix switch_layer_issues

This commit is contained in:
panyifeng 2020-08-19 12:58:51 +08:00
parent e711aecdc2
commit f9f3cd7ce0
6 changed files with 126 additions and 8 deletions

View File

@ -462,7 +462,8 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
std::vector<FuncGraphPtr> graphs{};
auto graphs_cnode = sw->input(2)->cast<CNodePtr>();
auto &graphs_inputs = graphs_cnode->inputs();
if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(graphs_inputs[1])) {
if (IsPrimitiveCNode(graphs_cnode, prim::kPrimMakeTuple) && graphs_inputs.size() >= 2 &&
IsValueNode<FuncGraph>(graphs_inputs[1])) {
(void)std::transform(graphs_inputs.begin() + 1, graphs_inputs.end(), std::back_inserter(graphs),
[](const AnfNodePtr &vnode) { return GetValueNode<FuncGraphPtr>(vnode); });
}

View File

@ -89,6 +89,7 @@ class GetItemTransformACrossGraph {
ss << idx;
auto new_fg_outer = TransformableClone(fg, std::make_shared<TraceTransform>(ss.str()));
fg->manager()->AddFuncGraph(new_fg_outer);
auto output_outer = new_fg_outer->output();
if (!IsValueNode<FuncGraph>(output_outer)) {
MS_LOG(WARNING) << "Output of outer graph should be a func_graph";
@ -486,7 +487,7 @@ class IncorporateGetitemSwitchLayerA : public AnfVisitor {
switch_layer_ = inputs[0];
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_));
}
if (is_in_switch_ && cnode->size() > 2) {
if (is_in_switch_ && cnode->size() >= 2) {
auto &inputs = cnode->inputs();
if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(inputs[1])) {
(void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_),
@ -578,7 +579,7 @@ class IncorporateGetitemSwitchLayerB : public AnfVisitor {
switch_layer_call_ = inputs[0];
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(outer_call_args_));
}
if (is_in_switch_ && cnode->size() > 2) {
if (is_in_switch_ && cnode->size() >= 2) {
auto &inputs = cnode->inputs();
if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple) && IsValueNode<FuncGraph>(inputs[1])) {
(void)std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(graphs_),

View File

@ -36,11 +36,10 @@ class SwitchLayerDeferInline : public AnfVisitor {
auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->inputs()[2]->abstract());
for (auto elem : tuple->elements()) {
auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem);
if (abstract == nullptr) {
return nullptr;
}
if (abstract != nullptr) {
*(abstract->func_graph()->switch_layer_input()) = true;
}
}
return nullptr;
}
};

View File

@ -137,6 +137,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.arithmetic_simplify2_,
irpass.same_eliminate_,
irpass.check_bprop_eliminate_,
irpass.switch_layer_defer_inline_,
irpass.replace_applicator_,
});
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});

View File

@ -16,6 +16,7 @@
#include "abstract/param_validator.h"
#include "abstract/infer_functions.h"
#include "abstract/abstract_function.h"
#include "abstract/utils.h"
#include "utils/symbolic.h"
@ -121,12 +122,18 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
for (size_t i = 0; i < branches.size(); i++) {
MS_EXCEPTION_IF_NULL(branches[i]);
if (!branches[i]->isa<AbstractFunction>()) {
MS_LOG(EXCEPTION) << op_name << " requires that the 2th arg be tuple of functions, but got "
MS_EXCEPTION(ValueError) << op_name << " requires that the 2th arg be tuple of functions, but got "
<< branches[i]->ToString() << " as the " << i << "th element.";
}
}
auto b = branches[0];
// Return an AbstractFuncUnion even when there is only one branch; otherwise the switch_layer would be
// replaced by branches[0], which would skip the out-of-bounds check on the index.
if (branches.size() == 1) {
AbstractFuncAtomPtrList func_list{b->cast<AbstractFuncAtomPtr>()};
return std::make_shared<AbstractFuncUnion>(func_list);
}
for (size_t i = 1; i < branches.size(); i++) {
b = b->Join(branches[i]);
}

View File

@ -444,6 +444,86 @@ def test_index_to_switch_layer():
C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
def test_parser_switch_layer_switch_in_bprop():
    """Run grad_all_with_sens on a cell whose custom bprop calls self.funcs[i] with a tensor index."""

    class OneInputBprop(nn.Cell):
        """ReLU in the forward pass; the bprop returns i together with self.funcs[i](x, dout)."""

        def __init__(self, funcs):
            super().__init__()
            self.op = P.ReLU()
            self.funcs = funcs

        def construct(self, i, x):
            return self.op(x)

        def bprop(self, i, x, out, dout):
            return i, self.funcs[i](x, dout)

    class Add(nn.Cell):
        """Two-input cell that applies P.TensorAdd."""

        def __init__(self):
            super().__init__()
            self.op = P.TensorAdd()

        def construct(self, x, y):
            return self.op(x, y)

    class Mul(nn.Cell):
        """Two-input cell that applies P.Mul."""

        def __init__(self):
            super().__init__()
            self.op = P.Mul()

        def construct(self, x, y):
            return self.op(x, y)

    # Candidate cells in index order: 0 -> Add, 1 -> Mul.
    branches = (Add(), Mul())
    bprop_net = OneInputBprop(branches)

    data = Tensor(np.ones([2, 2]).astype(np.float32))
    sens = Tensor(np.random.randn(2, 2).astype(np.float32))
    index = Tensor(1, mstype.int32)

    # Only needs to run without raising; the gradients themselves are not inspected.
    C.grad_all_with_sens(bprop_net)(index, data, sens)
def test_parser_switch_layer_inputs_tuple():
    """Run grad_all_with_sens on a net that calls self.funcs[i] with a single tuple argument."""

    class TwoInputTupleFinalNet(nn.Cell):
        """Packs its two data inputs into a tuple and passes it to the cell selected by i."""

        def __init__(self, funcs):
            super().__init__()
            self.funcs = funcs

        def construct(self, i, inputa, inputb):
            inputs = (inputa, inputb)
            x = self.funcs[i](inputs)
            return x

    class Add(nn.Cell):
        """Takes a 2-tuple x and returns x[0] + (x[0] + x[1]) via P.TensorAdd."""

        def __init__(self):
            super().__init__()
            self.op = P.TensorAdd()

        def construct(self, x):
            y = self.op(x[0], x[1])
            return self.op(x[0], y)

    class Mul(nn.Cell):
        """Takes a 2-tuple x and returns x[0] * (x[0] * x[1]) via P.Mul."""

        def __init__(self):
            super().__init__()
            self.op = P.Mul()

        def construct(self, x):
            y = self.op(x[0], x[1])
            return self.op(x[0], y)

    # Candidate cells in index order: 0 -> Add, 1 -> Mul.
    branches = (Add(), Mul())
    tuple_net = TwoInputTupleFinalNet(branches)

    shape = (2, 3, 4, 5)
    first = Tensor(np.random.randn(*shape).astype(np.float32))
    second = Tensor(np.random.randn(*shape).astype(np.float32))
    index = Tensor(1, mstype.int32)
    sens = Tensor(np.random.randn(*shape).astype(np.float32))

    # The result is bound but never checked: the test only requires the call to succeed.
    backward = C.grad_all_with_sens(tuple_net)
    back_out = backward(index, first, second, sens)
def test_switch_layer_with_single_prim():
class SwitchLayerCell(nn.Cell):
def __init__(self):
@ -494,6 +574,35 @@ def test_switch_layer_env_eliminate():
net2(x, i)
def test_switch_layer_single_layer():
    """Differentiate, w.r.t. the weights, a net that indexes a one-element tuple of cells."""

    class Net(nn.Cell):
        """Holds a single Conv2d in a 1-tuple and calls the entry chosen by `index`."""

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, 1, 3, pad_mode='same')
            self.funs = (self.conv,)

        def construct(self, x, index):
            x = self.funs[index](x)
            return x

    class NetGrad(nn.Cell):
        """Wraps `net` and returns its gradients for the net's trainable parameters."""

        def __init__(self, net):
            super().__init__()
            self.grad_op = C.GradOperation('grad', get_by_list=True, sens_param=False)
            self.net = net
            self.weights = ParameterTuple(self.net.trainable_params())

        def construct(self, x, index):
            weights = self.weights
            grad = self.grad_op(self.net, weights)(x, index)
            return grad

    grad_net = NetGrad(Net())
    data = Tensor(np.ones((3, 1, 12, 12)), ms.float32)
    # NOTE(review): index 1 with a one-element tuple looks deliberate (exercising the
    # single-branch path) -- confirm before changing it.
    index = Tensor(1, ms.int32)
    grad_net(data, index)
def test_control_depend_check():
with pytest.raises(TypeError) as e:
P.ControlDepend(0.0)