forked from mindspore-Ecosystem/mindspore
!12949 Add TopoSort Rhs First attribute for special CNode, such as Depend CNode with isolated nodes.
From: @zh_qh Reviewed-by: @hwhewei,@zhunaipan Signed-off-by: @zhunaipan
This commit is contained in:
commit
48d4cca512
|
@ -228,13 +228,13 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
|
|||
bool change = false;
|
||||
auto res = DoTransform(optimizer, node, substitution);
|
||||
if (res != nullptr) {
|
||||
if (is_once_) {
|
||||
return true;
|
||||
}
|
||||
change = true;
|
||||
changes = true;
|
||||
node = res;
|
||||
}
|
||||
if (change && is_once_) {
|
||||
return true;
|
||||
}
|
||||
UpdateTransformingList(optimizer, node, &todo, change, seen);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,14 +17,17 @@
|
|||
*/
|
||||
|
||||
#include "pipeline/jit/parse/function_block.h"
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
#include "pipeline/jit/parse/parse.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "utils/info.h"
|
||||
#include "debug/trace.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace py = pybind11;
|
||||
|
@ -435,7 +438,10 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
|
|||
old_output = NewValueNode(kNone);
|
||||
}
|
||||
AnfNodePtr stop_grad_node = func_graph()->NewCNode({NewValueNode(prim::kPrimStopGradient), state});
|
||||
AnfNodePtr depend_node = func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), old_output, stop_grad_node});
|
||||
CNodePtr depend_node = func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), old_output, stop_grad_node});
|
||||
// We add this attribute for @constexpr use scene, since we must infer them before other nodes.
|
||||
// That means isolated nodes will be evaluated first. It's not complete, but works in most scenes.
|
||||
depend_node->AddAttr(kAttrTopoSortRhsFirst, MakeValue(true));
|
||||
MS_LOG(INFO) << "Attached for side-effect nodes, depend_node: " << depend_node->DebugString()
|
||||
<< ", state: " << state->DebugString(2);
|
||||
func_graph()->set_output(depend_node, true);
|
||||
|
|
|
@ -397,6 +397,7 @@ constexpr auto kAttrRecompute = "recompute";
|
|||
constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
|
||||
constexpr auto kAttrParallelDimInfo = "parallel_dim_info";
|
||||
constexpr auto kAttrStitch = "stitch";
|
||||
constexpr auto kAttrTopoSortRhsFirst = "topo_sort_rhs_first";
|
||||
|
||||
// attr value
|
||||
constexpr auto kValueTargetSwitch = "target_switch";
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include "ir/func_graph.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "mindspore/ccsrc/utils/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) {
|
||||
|
@ -170,8 +171,16 @@ std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(FuncGraphPtr root) {
|
|||
static void PushSuccessors(const CNodePtr &cnode, std::vector<AnfNodePtr> *vecs) {
|
||||
auto &inputs = cnode->inputs();
|
||||
vecs->reserve(vecs->size() + inputs.size());
|
||||
// To keep evaluate order from left to right, we push inputs in reversed order.
|
||||
vecs->insert(vecs->end(), inputs.rbegin(), inputs.rend());
|
||||
|
||||
// To keep sort order from left to right in default, if kAttrTopoSortRhsFirst not set.
|
||||
auto attr_sort_rhs_first = cnode->GetAttr(kAttrTopoSortRhsFirst);
|
||||
auto sort_rhs_first =
|
||||
attr_sort_rhs_first != nullptr && attr_sort_rhs_first->isa<BoolImm>() && GetValue<bool>(attr_sort_rhs_first);
|
||||
if (sort_rhs_first) {
|
||||
vecs->insert(vecs->end(), inputs.cbegin(), inputs.cend());
|
||||
} else {
|
||||
vecs->insert(vecs->end(), inputs.crbegin(), inputs.crend());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {
|
||||
|
|
|
@ -174,8 +174,8 @@ def test_dot_008():
|
|||
network = NetDot()
|
||||
try:
|
||||
network(x2_tensor, x1_tensor)
|
||||
except IndexError as e:
|
||||
assert IndexError == type(e)
|
||||
except ValueError as e:
|
||||
assert ValueError == type(e)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
|
@ -82,7 +82,7 @@ def test_check_multifield_embedding_false_type_field_id():
|
|||
|
||||
@non_graph_engine
|
||||
def test_check_multifield_embedding_false_input_shape():
|
||||
with pytest.raises(IndexError):
|
||||
with pytest.raises(ValueError):
|
||||
compile_multi_field_embedding((8,), (8, 200), (8, 200),
|
||||
dtype.int16, dtype.float32, dtype.int16)
|
||||
|
||||
|
|
|
@ -84,7 +84,7 @@ def test_ssim_different_shape():
|
|||
img1 = Tensor(np.random.random(shape_1))
|
||||
img2 = Tensor(np.random.random(shape_2))
|
||||
net = SSIMNet()
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
_executor.compile(net, img1, img2)
|
||||
|
||||
|
||||
|
@ -108,9 +108,9 @@ def test_ssim_invalid_5d_input():
|
|||
invalid_img2 = Tensor(np.random.random(invalid_shape))
|
||||
|
||||
net = SSIMNet()
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
_executor.compile(net, invalid_img1, img2)
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
_executor.compile(net, img1, invalid_img2)
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
_executor.compile(net, invalid_img1, invalid_img2)
|
||||
|
|
Loading…
Reference in New Issue