!12949 Add TopoSort Rhs First attribute for special CNode, such as Depend CNode with isolated nodes.

From: @zh_qh
Reviewed-by: @hwhewei,@zhunaipan
Signed-off-by: @zhunaipan
This commit is contained in:
mindspore-ci-bot 2021-03-06 14:41:35 +08:00 committed by Gitee
commit 48d4cca512
7 changed files with 30 additions and 14 deletions

View File

@ -228,13 +228,13 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
bool change = false;
auto res = DoTransform(optimizer, node, substitution);
if (res != nullptr) {
if (is_once_) {
return true;
}
change = true;
changes = true;
node = res;
}
if (change && is_once_) {
return true;
}
UpdateTransformingList(optimizer, node, &todo, change, seen);
}

View File

@ -17,14 +17,17 @@
*/
#include "pipeline/jit/parse/function_block.h"
#include <string>
#include <memory>
#include "pybind11/pybind11.h"
#include "pipeline/jit/parse/resolve.h"
#include "pipeline/jit/parse/parse.h"
#include "frontend/operator/ops.h"
#include "utils/info.h"
#include "debug/trace.h"
#include "pybind11/pybind11.h"
#include "utils/utils.h"
namespace mindspore {
namespace py = pybind11;
@ -435,7 +438,10 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
old_output = NewValueNode(kNone);
}
AnfNodePtr stop_grad_node = func_graph()->NewCNode({NewValueNode(prim::kPrimStopGradient), state});
AnfNodePtr depend_node = func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), old_output, stop_grad_node});
CNodePtr depend_node = func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), old_output, stop_grad_node});
// We add this attribute for @constexpr use scene, since we must infer them before other nodes.
// That means isolated nodes will be evaluated first. It's not complete, but works in most scenes.
depend_node->AddAttr(kAttrTopoSortRhsFirst, MakeValue(true));
MS_LOG(INFO) << "Attached for side-effect nodes, depend_node: " << depend_node->DebugString()
<< ", state: " << state->DebugString(2);
func_graph()->set_output(depend_node, true);

View File

@ -397,6 +397,7 @@ constexpr auto kAttrRecompute = "recompute";
constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
constexpr auto kAttrParallelDimInfo = "parallel_dim_info";
constexpr auto kAttrStitch = "stitch";
constexpr auto kAttrTopoSortRhsFirst = "topo_sort_rhs_first";
// attr value
constexpr auto kValueTargetSwitch = "target_switch";

View File

@ -32,6 +32,7 @@
#include "ir/func_graph.h"
#include "utils/log_adapter.h"
#include "utils/ms_context.h"
#include "mindspore/ccsrc/utils/utils.h"
namespace mindspore {
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) {
@ -170,8 +171,16 @@ std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(FuncGraphPtr root) {
static void PushSuccessors(const CNodePtr &cnode, std::vector<AnfNodePtr> *vecs) {
auto &inputs = cnode->inputs();
vecs->reserve(vecs->size() + inputs.size());
// To keep evaluate order from left to right, we push inputs in reversed order.
vecs->insert(vecs->end(), inputs.rbegin(), inputs.rend());
// To keep sort order from left to right in default, if kAttrTopoSortRhsFirst not set.
auto attr_sort_rhs_first = cnode->GetAttr(kAttrTopoSortRhsFirst);
auto sort_rhs_first =
attr_sort_rhs_first != nullptr && attr_sort_rhs_first->isa<BoolImm>() && GetValue<bool>(attr_sort_rhs_first);
if (sort_rhs_first) {
vecs->insert(vecs->end(), inputs.cbegin(), inputs.cend());
} else {
vecs->insert(vecs->end(), inputs.crbegin(), inputs.crend());
}
}
std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {

View File

@ -174,8 +174,8 @@ def test_dot_008():
network = NetDot()
try:
network(x2_tensor, x1_tensor)
except IndexError as e:
assert IndexError == type(e)
except ValueError as e:
assert ValueError == type(e)
@pytest.mark.level0

View File

@ -82,7 +82,7 @@ def test_check_multifield_embedding_false_type_field_id():
@non_graph_engine
def test_check_multifield_embedding_false_input_shape():
with pytest.raises(IndexError):
with pytest.raises(ValueError):
compile_multi_field_embedding((8,), (8, 200), (8, 200),
dtype.int16, dtype.float32, dtype.int16)

View File

@ -84,7 +84,7 @@ def test_ssim_different_shape():
img1 = Tensor(np.random.random(shape_1))
img2 = Tensor(np.random.random(shape_2))
net = SSIMNet()
with pytest.raises(ValueError):
with pytest.raises(TypeError):
_executor.compile(net, img1, img2)
@ -108,9 +108,9 @@ def test_ssim_invalid_5d_input():
invalid_img2 = Tensor(np.random.random(invalid_shape))
net = SSIMNet()
with pytest.raises(ValueError):
with pytest.raises(TypeError):
_executor.compile(net, invalid_img1, img2)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
_executor.compile(net, img1, invalid_img2)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
_executor.compile(net, invalid_img1, invalid_img2)