forked from mindspore-Ecosystem/mindspore
adjust cse code when op has side effect.
This commit is contained in:
parent
af2ebca023
commit
ef13a4b6fb
|
@ -89,15 +89,28 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
|
|||
|
||||
return changed;
|
||||
}
|
||||
|
||||
// The op like print, summary, or the op do not has true output, and always as a depend node input.
|
||||
static bool HasSideEffect(const AnfNodePtr &node) {
|
||||
auto prim = GetCNodePrimitive(node);
|
||||
if (prim == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto side_effect_v = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT);
|
||||
if (side_effect_v != nullptr && side_effect_v->isa<BoolImm>()) {
|
||||
return GetValue<bool>(side_effect_v);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
// If true do not merge the node.
|
||||
bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const {
|
||||
bool has_random_effect = false;
|
||||
auto prim_main = GetCNodePrimitive(main);
|
||||
auto prim_node = GetCNodePrimitive(node);
|
||||
// if has random effect, when generate by different op (not same object), do not merge.
|
||||
if (prim_main != nullptr) {
|
||||
if (prim_main == prim_node) {
|
||||
return false;
|
||||
}
|
||||
if (prim_main != nullptr) {
|
||||
auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
|
||||
if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
|
||||
has_random_effect = GetValue<bool>(effect_val);
|
||||
|
@ -106,45 +119,58 @@ bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) cons
|
|||
return has_random_effect;
|
||||
}
|
||||
|
||||
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
|
||||
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const {
|
||||
MS_EXCEPTION_IF_NULL(main);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
bool replace = false;
|
||||
if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
|
||||
auto main_value = GetValueNode(main);
|
||||
auto node_value = GetValueNode(node);
|
||||
replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
|
||||
return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
|
||||
} else if (main->isa<CNode>() && node->isa<CNode>()) {
|
||||
auto c_main = main->cast<CNodePtr>();
|
||||
auto c_node = node->cast<CNodePtr>();
|
||||
// When appsame is true, check if has side effect, do not merge.
|
||||
if (check_side_effect && HasSideEffect(main)) {
|
||||
return false;
|
||||
}
|
||||
const auto &inp1 = c_main->inputs();
|
||||
const auto &inp2 = c_node->inputs();
|
||||
if (inp1.size() == inp2.size()) {
|
||||
bool appsame = true;
|
||||
if (inp1.size() != inp2.size()) {
|
||||
return false;
|
||||
}
|
||||
for (size_t j = 0; j < inp1.size(); j++) {
|
||||
MS_EXCEPTION_IF_NULL(inp1[j]);
|
||||
MS_EXCEPTION_IF_NULL(inp2[j]);
|
||||
if (!(*inp1[j] == *inp2[j])) {
|
||||
auto inp1_j = inp1[j];
|
||||
auto inp2_j = inp2[j];
|
||||
MS_EXCEPTION_IF_NULL(inp1_j);
|
||||
MS_EXCEPTION_IF_NULL(inp2_j);
|
||||
if (!(*inp1_j == *inp2_j)) {
|
||||
// Handle the case of two different Tensor, but with the same value
|
||||
if (IsValueNode<tensor::Tensor>(inp1[j]) && IsValueNode<tensor::Tensor>(inp2[j])) {
|
||||
auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1[j]);
|
||||
auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2[j]);
|
||||
if (IsValueNode<tensor::Tensor>(inp1_j) && IsValueNode<tensor::Tensor>(inp2_j)) {
|
||||
auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1_j);
|
||||
auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2_j);
|
||||
if (tensor1->ValueEqual(*tensor2)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
appsame = false;
|
||||
break;
|
||||
} else if (HasSideEffect(inp1_j) && HasSideEffect(inp2_j)) {
|
||||
// When the same side effect node as another two nodes' inputs, we still merge the node.
|
||||
// Because the node only can be the inputs of `depend`, when the `depend` is duplicated merge the depend the
|
||||
// node.
|
||||
if (CheckReplace(inp1_j, inp2_j, false)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// When appsame is true, check if has random effect do not merge
|
||||
if (CheckRandomEffect(c_main, c_node)) {
|
||||
appsame = false;
|
||||
return false;
|
||||
}
|
||||
replace = appsame;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return replace;
|
||||
// a parameter node.
|
||||
return false;
|
||||
}
|
||||
|
||||
bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group,
|
||||
|
|
|
@ -41,7 +41,7 @@ class CSE {
|
|||
return chg && report_changes_;
|
||||
}
|
||||
|
||||
virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const;
|
||||
virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const;
|
||||
|
||||
virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const;
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
|
||||
bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const {
|
||||
MS_EXCEPTION_IF_NULL(main);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ class BackendCSE : public CSE {
|
|||
public:
|
||||
BackendCSE() = default;
|
||||
~BackendCSE() override = default;
|
||||
bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const override;
|
||||
bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,5 +33,6 @@ const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll";
|
|||
const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
|
||||
const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
|
||||
const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
|
||||
const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect";
|
||||
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,7 +34,7 @@ extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[];
|
|||
extern const char GRAPH_FLAG_HAS_EFFECT[];
|
||||
extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
|
||||
extern const char GRAPH_FLAG_RANDOM_EFFECT[];
|
||||
|
||||
extern const char GRAPH_FLAG_SIDE_EFFECT[];
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // PYBIND_API_EXPORT_FLAGS_H_
|
||||
|
|
|
@ -220,7 +220,9 @@ class DataWrapper(Cell):
|
|||
|
||||
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
|
||||
super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
|
||||
|
||||
# Also copy the flag in `network` construct
|
||||
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
|
||||
self.add_flags(**flags)
|
||||
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
|
||||
self.network = network
|
||||
|
||||
|
|
|
@ -334,7 +334,7 @@ class Print(PrimitiveWithInfer):
|
|||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
pass
|
||||
self.add_prim_attr("_side_effect", True)
|
||||
|
||||
def __call__(self, *args):
|
||||
for arg in args:
|
||||
|
|
|
@ -336,6 +336,7 @@ class PrintNet(nn.Cell):
|
|||
def construct(self, int8, uint8, int16, uint16, int32, uint32, int64, uint64, flt16, flt32, flt64, bool_,
|
||||
scale1, scale2):
|
||||
self.print('============tensor int8:==============', int8)
|
||||
self.print('============tensor int8:==============', int8)
|
||||
self.print('============tensor uint8:==============', uint8)
|
||||
self.print('============tensor int16:==============', int16)
|
||||
self.print('============tensor uint16:==============', uint16)
|
||||
|
|
Loading…
Reference in New Issue