adjust cse code when op has side effect.

This commit is contained in:
Wei Luning 2020-06-25 00:45:10 +08:00
parent af2ebca023
commit ef13a4b6fb
9 changed files with 66 additions and 36 deletions

View File

@ -89,15 +89,28 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
return changed;
}
// The op like print, summary, or the op do not has true output, and always as a depend node input.
static bool HasSideEffect(const AnfNodePtr &node) {
auto prim = GetCNodePrimitive(node);
if (prim == nullptr) {
return false;
}
auto side_effect_v = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT);
if (side_effect_v != nullptr && side_effect_v->isa<BoolImm>()) {
return GetValue<bool>(side_effect_v);
}
return false;
}
// If true do not merge the node.
bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const {
bool has_random_effect = false;
auto prim_main = GetCNodePrimitive(main);
auto prim_node = GetCNodePrimitive(node);
if (prim_main == prim_node) {
return false;
}
// if has random effect, when generate by different op (not same object), do not merge.
if (prim_main != nullptr) {
if (prim_main == prim_node) {
return false;
}
auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
has_random_effect = GetValue<bool>(effect_val);
@ -106,45 +119,58 @@ bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) cons
return has_random_effect;
}
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const {
MS_EXCEPTION_IF_NULL(main);
MS_EXCEPTION_IF_NULL(node);
bool replace = false;
if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
auto main_value = GetValueNode(main);
auto node_value = GetValueNode(node);
replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
} else if (main->isa<CNode>() && node->isa<CNode>()) {
auto c_main = main->cast<CNodePtr>();
auto c_node = node->cast<CNodePtr>();
// When appsame is true, check if has side effect, do not merge.
if (check_side_effect && HasSideEffect(main)) {
return false;
}
const auto &inp1 = c_main->inputs();
const auto &inp2 = c_node->inputs();
if (inp1.size() == inp2.size()) {
bool appsame = true;
for (size_t j = 0; j < inp1.size(); j++) {
MS_EXCEPTION_IF_NULL(inp1[j]);
MS_EXCEPTION_IF_NULL(inp2[j]);
if (!(*inp1[j] == *inp2[j])) {
// Handle the case of two different Tensor, but with the same value
if (IsValueNode<tensor::Tensor>(inp1[j]) && IsValueNode<tensor::Tensor>(inp2[j])) {
auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1[j]);
auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2[j]);
if (tensor1->ValueEqual(*tensor2)) {
continue;
}
}
appsame = false;
break;
}
}
if (CheckRandomEffect(c_main, c_node)) {
appsame = false;
}
replace = appsame;
if (inp1.size() != inp2.size()) {
return false;
}
for (size_t j = 0; j < inp1.size(); j++) {
auto inp1_j = inp1[j];
auto inp2_j = inp2[j];
MS_EXCEPTION_IF_NULL(inp1_j);
MS_EXCEPTION_IF_NULL(inp2_j);
if (!(*inp1_j == *inp2_j)) {
// Handle the case of two different Tensor, but with the same value
if (IsValueNode<tensor::Tensor>(inp1_j) && IsValueNode<tensor::Tensor>(inp2_j)) {
auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1_j);
auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2_j);
if (tensor1->ValueEqual(*tensor2)) {
continue;
}
} else if (HasSideEffect(inp1_j) && HasSideEffect(inp2_j)) {
// When the same side effect node as another two nodes' inputs, we still merge the node.
// Because the node only can be the inputs of `depend`, when the `depend` is duplicated merge the depend the
// node.
if (CheckReplace(inp1_j, inp2_j, false)) {
continue;
}
}
return false;
}
}
// When appsame is true, check if has random effect do not merge
if (CheckRandomEffect(c_main, c_node)) {
return false;
}
return true;
}
return replace;
// a parameter node.
return false;
}
bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group,

View File

@ -41,7 +41,7 @@ class CSE {
return chg && report_changes_;
}
virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const;
virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const;
virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const;

View File

@ -35,7 +35,7 @@ bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) {
}
} // namespace
bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const {
bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const {
MS_EXCEPTION_IF_NULL(main);
MS_EXCEPTION_IF_NULL(node);

View File

@ -31,7 +31,7 @@ class BackendCSE : public CSE {
public:
BackendCSE() = default;
~BackendCSE() override = default;
bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) const override;
bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override;
};
} // namespace opt
} // namespace mindspore

View File

@ -33,5 +33,6 @@ const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll";
const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect";
} // namespace mindspore

View File

@ -34,7 +34,7 @@ extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[];
extern const char GRAPH_FLAG_HAS_EFFECT[];
extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
extern const char GRAPH_FLAG_RANDOM_EFFECT[];
extern const char GRAPH_FLAG_SIDE_EFFECT[];
} // namespace mindspore
#endif // PYBIND_API_EXPORT_FLAGS_H_

View File

@ -220,7 +220,9 @@ class DataWrapper(Cell):
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
super(DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
# Also copy the flag in `network` construct
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
self.add_flags(**flags)
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
self.network = network

View File

@ -334,7 +334,7 @@ class Print(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
pass
self.add_prim_attr("_side_effect", True)
def __call__(self, *args):
for arg in args:

View File

@ -336,6 +336,7 @@ class PrintNet(nn.Cell):
def construct(self, int8, uint8, int16, uint16, int32, uint32, int64, uint64, flt16, flt32, flt64, bool_,
scale1, scale2):
self.print('============tensor int8:==============', int8)
self.print('============tensor int8:==============', int8)
self.print('============tensor uint8:==============', uint8)
self.print('============tensor int16:==============', int16)
self.print('============tensor uint16:==============', uint16)