diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc index 38b59afe96b..c0e4158033a 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/value_based_eliminate.cc @@ -21,6 +21,52 @@ namespace opt { namespace irpass { #define UPPER_FLT_LIMIT (FLT_MAX / 2.0) #define LOWER_FLT_LIMIT (-FLT_MAX / 2.0) +// Define the checking mode +enum ScalarCheckingMode : int { GREATER_EQUAL = 0, LESS }; + +bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &checking_mode, const float &check_value) { + auto value_node = node->cast(); + if (value_node == nullptr) { + return false; + } + + auto value = value_node->value(); + if (value == nullptr) { + return false; + } + + auto scalar = value->cast(); + if (scalar != nullptr) { + if (scalar->isa()) { + if (checking_mode == GREATER_EQUAL) { + return GetValue(scalar) >= check_value; + } + return GetValue(scalar) < check_value; + } + } + // Check for Tensor [] or Tensor [1] + auto tensor_ptr = value->cast(); + if (tensor_ptr == nullptr) { + return false; + } + if (tensor_ptr->DataSize() > 1) { + return false; + } + + TypeId tensor_type = tensor_ptr->Dtype()->type_id(); + if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { + float *data = reinterpret_cast(tensor_ptr->data_c()); + if (checking_mode == GREATER_EQUAL) { + return data[0] >= check_value; + } + return data[0] < check_value; + } + + return false; +} + +// check if a value is greater or equal 0.0 +bool IsNodeScalarPositive(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, 0.0); } bool IsCNodePositive(const AnfNodePtr &node) { if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) { @@ -29,80 +75,22 @@ bool IsCNodePositive(const AnfNodePtr &node) { if (IsPrimitiveCNode(node, prim::kPrimSquare) || IsPrimitiveCNode(node, prim::kPrimSqrt)) { return true; } - return false; -} - -// check if a value is bigger than UPPER_FLT_LIMIT -bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { - auto value_node = node->cast(); - if (value_node == nullptr) { - return false; - } - - auto value = value_node->value(); - if (value == nullptr) { - return false; - } - - auto scalar = value->cast(); - if (scalar != nullptr) { - if (scalar->isa()) { - return GetValue(scalar) > UPPER_FLT_LIMIT; - } - } - // Check for Tensor [] or Tensor [1] - auto tensor_ptr = value->cast(); - if (tensor_ptr == nullptr) { - return false; - } - if (tensor_ptr->DataSize() > 1) { - return false; - } - - TypeId tensor_type = tensor_ptr->Dtype()->type_id(); - if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { - float *data = reinterpret_cast(tensor_ptr->data_c()); - return data[0] > UPPER_FLT_LIMIT; + if (IsPrimitiveCNode(node, prim::kPrimMinimum) || IsPrimitiveCNode(node, prim::kPrimRealDiv)) { + auto first_node_positive = + IsCNodePositive(node->cast()->input(1)) || IsNodeScalarPositive(node->cast()->input(1)); + auto second_node_positive = + IsCNodePositive(node->cast()->input(2)) || IsNodeScalarPositive(node->cast()->input(2)); + return first_node_positive && second_node_positive; } return false; } +// check if a value is greater or equal UPPER_FLT_LIMIT +bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, UPPER_FLT_LIMIT); } + // check if a value is smaller than LOWER_FLT_LIMIT -bool IsNodeScalarMinFLT(const AnfNodePtr &node) { - auto value_node = node->cast(); - if (value_node == nullptr) { - return false; - } - - auto value = value_node->value(); - if (value == nullptr) { - return false; - } - - auto scalar = value->cast(); - if (scalar != nullptr) { - if (scalar->isa()) { - return GetValue(scalar) < LOWER_FLT_LIMIT; - } - } - // Check for Tensor [] or Tensor [1] - auto tensor_ptr = value->cast(); - if (tensor_ptr == nullptr) { - return false; - } - if (tensor_ptr->DataSize() > 1) { - return false; - } - - TypeId tensor_type = tensor_ptr->Dtype()->type_id(); - if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { - float *data = reinterpret_cast(tensor_ptr->data_c()); - return data[0] < LOWER_FLT_LIMIT; - } - - return false; -} +bool IsNodeScalarMinFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, LESS, LOWER_FLT_LIMIT); } AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) { PatternNode x, y, z; @@ -116,10 +104,15 @@ AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePt MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_scalar_), y, z), y, IsCNodePositive(x.GetNode(node))); + // {prim::kPrimMaximum, X, LOWER_FLT_LIMIT}} -> X MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMaximum, x, y), x, IsNodeScalarMinFLT(y.GetNode(node))); + // {prim::kPrimMinimum, X, UPPER_FLT_LIMIT}} -> X MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMinimum, x, y), x, IsNodeScalarMaxFLT(y.GetNode(node))); + // {prim::kPrimMaximum, X, 0}} -> X when X is always greater or equal 0 + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMaximum, x, zero_), x, IsCNodePositive(x.GetNode(node))); + return nullptr; }