!4198 New optimization pass to remove redundant Max ops

Merge pull request !4198 from thlinh/dev_Aug08_eliminate_more_redundant_Max
This commit is contained in:
mindspore-ci-bot 2020-08-17 12:48:36 +08:00 committed by Gitee
commit 1b28f77be4
1 changed files with 61 additions and 68 deletions

View File

@ -21,6 +21,52 @@ namespace opt {
namespace irpass {
#define UPPER_FLT_LIMIT (FLT_MAX / 2.0)
#define LOWER_FLT_LIMIT (-FLT_MAX / 2.0)
// Define the checking mode
enum ScalarCheckingMode : int { GREATER_EQUAL = 0, LESS };
bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &checking_mode, const float &check_value) {
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
return false;
}
auto value = value_node->value();
if (value == nullptr) {
return false;
}
auto scalar = value->cast<ScalarPtr>();
if (scalar != nullptr) {
if (scalar->isa<FloatImm>()) {
if (checking_mode == GREATER_EQUAL) {
return GetValue<float>(scalar) >= check_value;
}
return GetValue<float>(scalar) < check_value;
}
}
// Check for Tensor [] or Tensor [1]
auto tensor_ptr = value->cast<tensor::TensorPtr>();
if (tensor_ptr == nullptr) {
return false;
}
if (tensor_ptr->DataSize() > 1) {
return false;
}
TypeId tensor_type = tensor_ptr->Dtype()->type_id();
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
float *data = reinterpret_cast<float *>(tensor_ptr->data_c());
if (checking_mode == GREATER_EQUAL) {
return data[0] >= check_value;
}
return data[0] < check_value;
}
return false;
}
// check if a value is greater or equal 0.0
bool IsNodeScalarPositive(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, 0.0); }
bool IsCNodePositive(const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) {
@ -29,80 +75,22 @@ bool IsCNodePositive(const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimSquare) || IsPrimitiveCNode(node, prim::kPrimSqrt)) {
return true;
}
return false;
}
// check if a value is bigger than UPPER_FLT_LIMIT
bool IsNodeScalarMaxFLT(const AnfNodePtr &node) {
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
return false;
}
auto value = value_node->value();
if (value == nullptr) {
return false;
}
auto scalar = value->cast<ScalarPtr>();
if (scalar != nullptr) {
if (scalar->isa<FloatImm>()) {
return GetValue<float>(scalar) > UPPER_FLT_LIMIT;
}
}
// Check for Tensor [] or Tensor [1]
auto tensor_ptr = value->cast<tensor::TensorPtr>();
if (tensor_ptr == nullptr) {
return false;
}
if (tensor_ptr->DataSize() > 1) {
return false;
}
TypeId tensor_type = tensor_ptr->Dtype()->type_id();
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
float *data = reinterpret_cast<float *>(tensor_ptr->data_c());
return data[0] > UPPER_FLT_LIMIT;
if (IsPrimitiveCNode(node, prim::kPrimMinimum) || IsPrimitiveCNode(node, prim::kPrimRealDiv)) {
auto first_node_positive =
IsCNodePositive(node->cast<CNodePtr>()->input(1)) || IsNodeScalarPositive(node->cast<CNodePtr>()->input(1));
auto second_node_positive =
IsCNodePositive(node->cast<CNodePtr>()->input(2)) || IsNodeScalarPositive(node->cast<CNodePtr>()->input(2));
return first_node_positive && second_node_positive;
}
return false;
}
// check if a value is greater or equal UPPER_FLT_LIMIT
bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, UPPER_FLT_LIMIT); }
// check if a value is smaller than LOWER_FLT_LIMIT
bool IsNodeScalarMinFLT(const AnfNodePtr &node) {
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
return false;
}
auto value = value_node->value();
if (value == nullptr) {
return false;
}
auto scalar = value->cast<ScalarPtr>();
if (scalar != nullptr) {
if (scalar->isa<FloatImm>()) {
return GetValue<float>(scalar) < LOWER_FLT_LIMIT;
}
}
// Check for Tensor [] or Tensor [1]
auto tensor_ptr = value->cast<tensor::TensorPtr>();
if (tensor_ptr == nullptr) {
return false;
}
if (tensor_ptr->DataSize() > 1) {
return false;
}
TypeId tensor_type = tensor_ptr->Dtype()->type_id();
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
float *data = reinterpret_cast<float *>(tensor_ptr->data_c());
return data[0] < LOWER_FLT_LIMIT;
}
return false;
}
bool IsNodeScalarMinFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, LESS, LOWER_FLT_LIMIT); }
AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
PatternNode x, y, z;
@ -116,10 +104,15 @@ AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePt
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_scalar_), y, z), y,
IsCNodePositive(x.GetNode(node)));
// {prim::kPrimMaximum, X, LOWER_FLT_LIMIT}} -> X
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMaximum, x, y), x, IsNodeScalarMinFLT(y.GetNode(node)));
// {prim::kPrimMinimum, X, UPPER_FLT_LIMIT}} -> X
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMinimum, x, y), x, IsNodeScalarMaxFLT(y.GetNode(node)));
// {prim::kPrimMaximum, X, 0}} -> X when X is always greater or equal 0
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMaximum, x, zero_), x, IsCNodePositive(x.GetNode(node)));
return nullptr;
}
} // namespace irpass