support if by if not inline
add testcase of net of if by if
This commit is contained in:
parent
33fdc43f18
commit
1bd9fefd84
|
@ -20,12 +20,14 @@
|
|||
#include <vector>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -153,23 +155,31 @@ class InlinerBase : public AnfVisitor {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> params;
|
||||
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params));
|
||||
std::vector<AnfNodePtr> args;
|
||||
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
|
||||
// compare size to avoid the case that the function has default value after grad.
|
||||
// for which after renormalize, the function default value will be an input
|
||||
if (fg->parameters().size() != params.size()) {
|
||||
if (fg->parameters().size() != args.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
// Not to inline after block if it has switch call inside, to avoid switch expansion.
|
||||
if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
|
||||
auto has_branch_call = GraphHasBranch(fg);
|
||||
if (has_branch_call) {
|
||||
return TransformBranchCall(fg, node, args);
|
||||
}
|
||||
}
|
||||
|
||||
if (use_move_ && IsUniqueUse(fg, nullptr)) {
|
||||
auto mng = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
ReplaceParams(mng, params, fg);
|
||||
ReplaceParams(mng, args, fg);
|
||||
auto out_node = fg->output();
|
||||
mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
|
||||
return out_node;
|
||||
}
|
||||
|
||||
return InlineClone(fg, node->func_graph(), params, inputs[0]->scope());
|
||||
return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
|
||||
}
|
||||
|
||||
void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector<AnfNodePtr> &new_params,
|
||||
|
@ -197,11 +207,89 @@ class InlinerBase : public AnfVisitor {
|
|||
is_checked_ = false;
|
||||
is_recursive_ = false;
|
||||
}
|
||||
// For after block which contains branch call, delete the parameters which is not used.
|
||||
// In most cases, it may be a `Module` or other constant input.
|
||||
AnfNodePtr TransformBranchCall(const FuncGraphPtr &fg, const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) {
|
||||
auto &fg_params = fg->parameters();
|
||||
std::vector<int> used_param_index;
|
||||
auto mng = fg->manager();
|
||||
for (size_t i = 0; i < fg_params.size(); i++) {
|
||||
if (mng->node_users()[fg_params[i]].size() != 0) {
|
||||
used_param_index.emplace_back(i);
|
||||
}
|
||||
}
|
||||
if (used_param_index.size() != fg_params.size()) {
|
||||
MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString();
|
||||
// clone a new graph and ignore the not used parameters
|
||||
FuncGraphPtr new_fg = TransformableClone(fg);
|
||||
auto &new_fg_params = new_fg->parameters();
|
||||
std::vector<AnfNodePtr> new_params;
|
||||
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params),
|
||||
[&new_fg_params](size_t i) { return new_fg_params[i]; });
|
||||
new_fg->set_parameters(new_params);
|
||||
std::vector<AnfNodePtr> node_inputs;
|
||||
node_inputs.push_back(NewValueNode(new_fg));
|
||||
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs),
|
||||
[&args](size_t i) { return args[i]; });
|
||||
return node->func_graph()->NewCNode(node_inputs);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// This is a try-best algorithm to find a graph which may generate branch call.
|
||||
// It does not handle high-order function call. For high-orderer call branch, it still may be inlined.
|
||||
bool GraphHasBranch(FuncGraphPtr fg) {
|
||||
if (graph_branch_cache_.find(fg) != graph_branch_cache_.end()) {
|
||||
return graph_branch_cache_[fg];
|
||||
}
|
||||
bool has_branch = false;
|
||||
auto nodes = fg->nodes();
|
||||
for (auto &item : nodes) {
|
||||
if (IsPrimitiveCNode(item, prim::kPrimSwitch)) {
|
||||
auto sw_inputs = item->cast<CNodePtr>()->inputs();
|
||||
if (sw_inputs.size() != 4) {
|
||||
MS_LOG(EXCEPTION) << "switch inputs should be 4";
|
||||
}
|
||||
if (!sw_inputs[1]->isa<ValueNode>() || IsValueNode<tensor::Tensor>(sw_inputs[1])) {
|
||||
has_branch = true;
|
||||
break;
|
||||
}
|
||||
} else if (IsCNodeGraph(item)) {
|
||||
auto cinputs = item->cast<CNodePtr>()->inputs();
|
||||
if (cinputs.size() < 1) {
|
||||
MS_LOG(EXCEPTION) << "graph call inputs should greater than 1";
|
||||
}
|
||||
FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[0]);
|
||||
bool call_fg_has_branch = GraphHasBranch(call_fg);
|
||||
if (call_fg_has_branch) {
|
||||
has_branch = true;
|
||||
break;
|
||||
}
|
||||
} else if (IsPrimitiveCNode(item, prim::kPrimPartial)) {
|
||||
auto cinputs = item->cast<CNodePtr>()->inputs();
|
||||
if (cinputs.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "partial call inputs should greater than 2";
|
||||
}
|
||||
FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[1]);
|
||||
if (call_fg == nullptr) {
|
||||
continue;
|
||||
}
|
||||
bool call_fg_has_branch = GraphHasBranch(call_fg);
|
||||
if (call_fg_has_branch) {
|
||||
has_branch = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
graph_branch_cache_[fg] = has_branch;
|
||||
return has_branch;
|
||||
}
|
||||
|
||||
private:
|
||||
bool is_checked_{false}, is_recursive_{false};
|
||||
bool use_move_;
|
||||
std::vector<std::pair<CriterionFuncType, bool>> criterions_;
|
||||
std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_;
|
||||
};
|
||||
|
||||
class Inliner : public InlinerBase {
|
||||
|
|
|
@ -1029,6 +1029,12 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
|
|||
FunctionBlockPtr after_block = MakeFunctionBlock(*this);
|
||||
TraceManager::EndTrace();
|
||||
|
||||
if (MsContext::GetInstance()->backend_policy() != "ge") {
|
||||
// for backends excludes 'ge', it can handle multi graph call, use this flag to
|
||||
// generate call not inline `after_block` graph to reduce if by if switch expansion.
|
||||
after_block->func_graph()->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true);
|
||||
}
|
||||
|
||||
// process the if-true branch
|
||||
py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
|
||||
FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode);
|
||||
|
|
|
@ -74,6 +74,7 @@ using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
|
|||
|
||||
const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
|
||||
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
|
||||
const char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block";
|
||||
const char FUNC_GRAPH_FLAG_CORE[] = "core";
|
||||
const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
|
||||
const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
|
||||
|
|
|
@ -42,7 +42,7 @@ def test_while_forward():
|
|||
idx = idx + 1
|
||||
return x
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
net = MyWhileNet()
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(2), dtype=ms.int32)
|
||||
|
@ -72,7 +72,7 @@ def test_while_grad():
|
|||
def construct(self, *inputs):
|
||||
return C.grad_all(self.net)(*inputs)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
@ -99,7 +99,7 @@ def test_while_with_param_forward():
|
|||
idx = idx + 1
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
net = MyWhileNet()
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(2), dtype=ms.int32)
|
||||
|
@ -124,7 +124,7 @@ def test_while_endless_case():
|
|||
idx = idx + 1
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
net = MyWhileNet()
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
end = Tensor(np.array(2), dtype=ms.int32)
|
||||
|
@ -159,7 +159,7 @@ def test_while_with_param_grad():
|
|||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
@ -187,7 +187,7 @@ def test_while_with_param_forward_with_const_branch():
|
|||
idx = idx + 1
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
while_net = MyWhileNet()
|
||||
net = while_net
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
@ -224,7 +224,7 @@ def test_while_opt_endless():
|
|||
def construct(self, *inputs):
|
||||
return C.grad_all(self.net)(*inputs)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
@ -250,7 +250,7 @@ def test_no_while_call():
|
|||
out = out + idx + self.param
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
while_net = MyWhileNet()
|
||||
net = while_net
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
@ -287,7 +287,7 @@ def test_while_with_param_grad_with_const_branch():
|
|||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
@ -327,7 +327,7 @@ def test_for_while_with_param_grad_with_const_branch():
|
|||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
@ -364,7 +364,7 @@ def test_for_while_with_param_grad_basic():
|
|||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
@ -401,7 +401,7 @@ def test_for_while_with_param_grad_normal():
|
|||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
@ -435,7 +435,7 @@ def test_while_with_param_basic_grad():
|
|||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
@ -469,7 +469,7 @@ def test_while_with_param_basic_grad_mul():
|
|||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
@ -504,7 +504,7 @@ def test_while_with_param_basic_grad_two():
|
|||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
@ -540,7 +540,7 @@ def test_while_with_param_basic_grad_three():
|
|||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
@ -577,7 +577,7 @@ def test_while_if_with_param_grad():
|
|||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
@ -610,7 +610,7 @@ def test_while_with_param_grad_not_enter_while():
|
|||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
idx = Tensor(np.array(3), dtype=ms.int32)
|
||||
|
@ -639,7 +639,7 @@ def test_with_param_if_by_if_forward():
|
|||
out = out + x*2
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if_net = MyIfByIfNet()
|
||||
net = if_net
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
@ -672,7 +672,7 @@ def test_with_param_if_by_if_grad_inputs():
|
|||
def construct(self, *inputs):
|
||||
return C.grad_all(self.net)(*inputs)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if_net = MyIfByIfNet()
|
||||
net = GradNet(if_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
@ -706,7 +706,7 @@ def test_with_param_if_by_if_grad_parameter():
|
|||
def construct(self, *inputs):
|
||||
return C.grad_by_list(self.net, self.weights)(*inputs)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if_net = MyIfByIfNet()
|
||||
net = GradNet(if_net)
|
||||
idx = Tensor(np.array(0), dtype=ms.int32)
|
||||
|
@ -738,7 +738,7 @@ def test_with_param_if_by_if_grad_param_excute_null():
|
|||
def construct(self, *inputs):
|
||||
return C.grad_by_list(self.net, self.weights)(*inputs)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if_net = MyIfByIfNet()
|
||||
net = GradNet(if_net)
|
||||
idx = Tensor(np.array(4), dtype=ms.int32)
|
||||
|
@ -772,7 +772,7 @@ def test_if_by_if_return_inside_grad():
|
|||
def construct(self, *inputs):
|
||||
return C.grad_by_list(self.net, self.weights)(*inputs)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if_net = MyIfByIfNet()
|
||||
net = GradNet(if_net)
|
||||
idx = Tensor(np.array(1), dtype=ms.int32)
|
||||
|
@ -807,10 +807,342 @@ def test_if_by_if_forward():
|
|||
out = a + b + x
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if_net = MyIfByIfNet()
|
||||
net = if_net
|
||||
idx = Tensor(np.array(2), dtype=ms.float32)
|
||||
end = Tensor(np.array(3), dtype=ms.float32)
|
||||
x = Tensor(np.array(4), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_if_by_if_forward_control_tuple_switch():
|
||||
"""tuple_get from swtich op will generate new switch inside to eliminate tuple_get"""
|
||||
class Branch3Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add = P.TensorAdd()
|
||||
self.sub = P.Sub()
|
||||
self.mul = P.Mul()
|
||||
self.div = P.RealDiv()
|
||||
|
||||
def construct(self, a, b, x):
|
||||
if b == x:
|
||||
b = self.add(a, b)
|
||||
else:
|
||||
b = self.add(a, x)
|
||||
return a, b, x
|
||||
class Branch2Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add = P.TensorAdd()
|
||||
self.sub = P.Sub()
|
||||
self.mul = P.Mul()
|
||||
self.div = P.RealDiv()
|
||||
self.net = Branch3Net()
|
||||
|
||||
def construct(self, a, b, x):
|
||||
if a == x:
|
||||
a = self.mul(a, b)
|
||||
else:
|
||||
a = self.div(a, b)
|
||||
return self.net(a, b, x)
|
||||
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add = P.TensorAdd()
|
||||
self.sub = P.Sub()
|
||||
self.mul = P.Mul()
|
||||
self.div = P.RealDiv()
|
||||
self.net = Branch2Net()
|
||||
|
||||
def construct(self, a, b, x):
|
||||
if a < b:
|
||||
a = self.add(a, b)
|
||||
else:
|
||||
a = self.sub(a, b)
|
||||
a, b, x = self.net(a, b, x)
|
||||
a = a * b
|
||||
out = a + b + x
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if_net = MyIfByIfNet()
|
||||
net = if_net
|
||||
idx = Tensor(np.array(2), dtype=ms.float32)
|
||||
end = Tensor(np.array(3), dtype=ms.float32)
|
||||
x = Tensor(np.array(0), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_if_by_if_forward_control_inside_net():
|
||||
class Branch3Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add = P.TensorAdd()
|
||||
self.sub = P.Sub()
|
||||
self.mul = P.Mul()
|
||||
self.div = P.RealDiv()
|
||||
|
||||
def construct(self, a, b, x):
|
||||
if b == x:
|
||||
b = self.add(a, b)
|
||||
else:
|
||||
b = self.add(a, x)
|
||||
a = a * b
|
||||
out = a + b + x
|
||||
return out
|
||||
class Branch2Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add = P.TensorAdd()
|
||||
self.sub = P.Sub()
|
||||
self.mul = P.Mul()
|
||||
self.div = P.RealDiv()
|
||||
self.net = Branch3Net()
|
||||
|
||||
def construct(self, a, b, x):
|
||||
if a == x:
|
||||
a = self.mul(a, b)
|
||||
else:
|
||||
a = self.div(a, b)
|
||||
return self.net(a, b, x)
|
||||
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add = P.TensorAdd()
|
||||
self.sub = P.Sub()
|
||||
self.mul = P.Mul()
|
||||
self.div = P.RealDiv()
|
||||
self.net = Branch2Net()
|
||||
|
||||
def construct(self, a, b, x):
|
||||
if a < b:
|
||||
a = self.add(a, b)
|
||||
else:
|
||||
a = self.sub(a, b)
|
||||
out = self.net(a, b, x)
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if_net = MyIfByIfNet()
|
||||
net = if_net
|
||||
idx = Tensor(np.array(2), dtype=ms.float32)
|
||||
end = Tensor(np.array(3), dtype=ms.float32)
|
||||
x = Tensor(np.array(0), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
|
||||
def test_if_by_if_forward_use_namespace():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add = P.TensorAdd()
|
||||
self.sub = P.Sub()
|
||||
self.mul = P.Mul()
|
||||
self.div = P.RealDiv()
|
||||
|
||||
def construct(self, a, b, x):
|
||||
if a < b:
|
||||
a = P.TensorAdd()(a, b)
|
||||
else:
|
||||
a = P.Sub()(a, b)
|
||||
if a == x:
|
||||
a = P.Mul()(a, b)
|
||||
else:
|
||||
a = P.RealDiv()(a, b)
|
||||
if b == x:
|
||||
b = P.TensorAdd()(a, b)
|
||||
else:
|
||||
b = P.TensorAdd()(a, x)
|
||||
a = a * b
|
||||
out = a + b + x
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if_net = MyIfByIfNet()
|
||||
net = if_net
|
||||
idx = Tensor(np.array(2), dtype=ms.float32)
|
||||
end = Tensor(np.array(3), dtype=ms.float32)
|
||||
x = Tensor(np.array(0), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_if_by_if_forward_use_global_op():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add = P.TensorAdd()
|
||||
self.sub = P.Sub()
|
||||
self.mul = P.Mul()
|
||||
self.div = P.RealDiv()
|
||||
|
||||
def construct(self, a, b, x):
|
||||
add = P.TensorAdd()
|
||||
sub = P.Sub()
|
||||
mul = P.Mul()
|
||||
div = P.RealDiv()
|
||||
if a < b:
|
||||
a = add(a, b)
|
||||
else:
|
||||
a = sub(a, b)
|
||||
if a == x:
|
||||
a = mul(a, b)
|
||||
else:
|
||||
a = div(a, b)
|
||||
if b == x:
|
||||
b = add(a, b)
|
||||
else:
|
||||
b = add(a, x)
|
||||
a = a * b
|
||||
out = a + b + x
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if_net = MyIfByIfNet()
|
||||
net = if_net
|
||||
idx = Tensor(np.array(2), dtype=ms.float32)
|
||||
end = Tensor(np.array(3), dtype=ms.float32)
|
||||
x = Tensor(np.array(0), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
def test_for_with_if_by_if_forward():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add = P.TensorAdd()
|
||||
self.sub = P.Sub()
|
||||
|
||||
def construct(self, a, b, x):
|
||||
for _ in range(0, 4):
|
||||
if a < b:
|
||||
a = self.add(a, b)
|
||||
else:
|
||||
b = self.sub(b, x)
|
||||
a = a * b
|
||||
out = a + b + x
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if_net = MyIfByIfNet()
|
||||
net = if_net
|
||||
idx = Tensor(np.array(2), dtype=ms.float32)
|
||||
end = Tensor(np.array(3), dtype=ms.float32)
|
||||
x = Tensor(np.array(0), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
|
||||
def test_for_with_if_by_if_forward_namespace():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add = P.TensorAdd()
|
||||
self.sub = P.Sub()
|
||||
self.mul = P.Mul()
|
||||
self.div = P.RealDiv()
|
||||
|
||||
def construct(self, a, b, x):
|
||||
for _ in range(0, 6):
|
||||
if a < b:
|
||||
a = P.TensorAdd()(a, b)
|
||||
else:
|
||||
b = P.Sub()(b, x)
|
||||
a = a * b
|
||||
out = a + b + x
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if_net = MyIfByIfNet()
|
||||
net = if_net
|
||||
idx = Tensor(np.array(2), dtype=ms.float32)
|
||||
end = Tensor(np.array(3), dtype=ms.float32)
|
||||
x = Tensor(np.array(0), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
|
||||
def test_if_by_if_forward_const_branch_inner():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add = P.TensorAdd()
|
||||
self.sub = P.Sub()
|
||||
self.mul = P.Mul()
|
||||
self.div = P.RealDiv()
|
||||
|
||||
def construct(self, a, b, x):
|
||||
add = P.TensorAdd()
|
||||
sub = P.Sub()
|
||||
mul = P.Mul()
|
||||
div = P.RealDiv()
|
||||
if a < b:
|
||||
a = add(a, b)
|
||||
else:
|
||||
a = sub(a, b)
|
||||
if 2 > 1:
|
||||
a = mul(a, b)
|
||||
else:
|
||||
a = div(a, b)
|
||||
if b == x:
|
||||
b = add(a, b)
|
||||
else:
|
||||
b = add(a, x)
|
||||
a = a * b
|
||||
out = a + b + x
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if_net = MyIfByIfNet()
|
||||
net = if_net
|
||||
idx = Tensor(np.array(2), dtype=ms.float32)
|
||||
end = Tensor(np.array(3), dtype=ms.float32)
|
||||
x = Tensor(np.array(0), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_if_by_if_forward_all_const_branch():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add = P.TensorAdd()
|
||||
self.sub = P.Sub()
|
||||
self.mul = P.Mul()
|
||||
self.div = P.RealDiv()
|
||||
|
||||
def construct(self, a, b, x):
|
||||
add = P.TensorAdd()
|
||||
sub = P.Sub()
|
||||
mul = P.Mul()
|
||||
div = P.RealDiv()
|
||||
if 2 < 12:
|
||||
a = add(a, b)
|
||||
else:
|
||||
a = sub(a, b)
|
||||
if 2 > 1:
|
||||
a = mul(a, b)
|
||||
else:
|
||||
a = div(a, b)
|
||||
if 2 == 1:
|
||||
b = add(a, b)
|
||||
else:
|
||||
b = add(a, x)
|
||||
a = a * b
|
||||
out = a + b + x
|
||||
return out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if_net = MyIfByIfNet()
|
||||
net = if_net
|
||||
idx = Tensor(np.array(2), dtype=ms.float32)
|
||||
end = Tensor(np.array(3), dtype=ms.float32)
|
||||
x = Tensor(np.array(0), dtype=ms.float32)
|
||||
net(idx, end, x)
|
||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue