forked from mindspore-Ecosystem/mindspore
parent
dc961e461e
commit
d3367dcc77
|
@ -21,6 +21,52 @@ namespace opt {
|
|||
namespace irpass {
|
||||
#define UPPER_FLT_LIMIT (FLT_MAX / 2.0)
|
||||
#define LOWER_FLT_LIMIT (-FLT_MAX / 2.0)
|
||||
// Define the checking mode
|
||||
enum ScalarCheckingMode : int { GREATER_EQUAL = 0, LESS };
|
||||
|
||||
bool IsNodeScalarTrueWith(const AnfNodePtr &node, const ScalarCheckingMode &checking_mode, const float &check_value) {
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
if (value_node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto value = value_node->value();
|
||||
if (value == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto scalar = value->cast<ScalarPtr>();
|
||||
if (scalar != nullptr) {
|
||||
if (scalar->isa<FloatImm>()) {
|
||||
if (checking_mode == GREATER_EQUAL) {
|
||||
return GetValue<float>(scalar) >= check_value;
|
||||
}
|
||||
return GetValue<float>(scalar) < check_value;
|
||||
}
|
||||
}
|
||||
// Check for Tensor [] or Tensor [1]
|
||||
auto tensor_ptr = value->cast<tensor::TensorPtr>();
|
||||
if (tensor_ptr == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (tensor_ptr->DataSize() > 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
TypeId tensor_type = tensor_ptr->Dtype()->type_id();
|
||||
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
|
||||
float *data = reinterpret_cast<float *>(tensor_ptr->data_c());
|
||||
if (checking_mode == GREATER_EQUAL) {
|
||||
return data[0] >= check_value;
|
||||
}
|
||||
return data[0] < check_value;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// check if a value is greater or equal 0.0
|
||||
bool IsNodeScalarPositive(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, 0.0); }
|
||||
|
||||
bool IsCNodePositive(const AnfNodePtr &node) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimSqueeze)) {
|
||||
|
@ -29,80 +75,22 @@ bool IsCNodePositive(const AnfNodePtr &node) {
|
|||
if (IsPrimitiveCNode(node, prim::kPrimSquare) || IsPrimitiveCNode(node, prim::kPrimSqrt)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// check if a value is bigger than UPPER_FLT_LIMIT
|
||||
bool IsNodeScalarMaxFLT(const AnfNodePtr &node) {
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
if (value_node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto value = value_node->value();
|
||||
if (value == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto scalar = value->cast<ScalarPtr>();
|
||||
if (scalar != nullptr) {
|
||||
if (scalar->isa<FloatImm>()) {
|
||||
return GetValue<float>(scalar) > UPPER_FLT_LIMIT;
|
||||
}
|
||||
}
|
||||
// Check for Tensor [] or Tensor [1]
|
||||
auto tensor_ptr = value->cast<tensor::TensorPtr>();
|
||||
if (tensor_ptr == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (tensor_ptr->DataSize() > 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
TypeId tensor_type = tensor_ptr->Dtype()->type_id();
|
||||
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
|
||||
float *data = reinterpret_cast<float *>(tensor_ptr->data_c());
|
||||
return data[0] > UPPER_FLT_LIMIT;
|
||||
if (IsPrimitiveCNode(node, prim::kPrimMinimum) || IsPrimitiveCNode(node, prim::kPrimRealDiv)) {
|
||||
auto first_node_positive =
|
||||
IsCNodePositive(node->cast<CNodePtr>()->input(1)) || IsNodeScalarPositive(node->cast<CNodePtr>()->input(1));
|
||||
auto second_node_positive =
|
||||
IsCNodePositive(node->cast<CNodePtr>()->input(2)) || IsNodeScalarPositive(node->cast<CNodePtr>()->input(2));
|
||||
return first_node_positive && second_node_positive;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// check if a value is greater or equal UPPER_FLT_LIMIT
|
||||
bool IsNodeScalarMaxFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, GREATER_EQUAL, UPPER_FLT_LIMIT); }
|
||||
|
||||
// check if a value is smaller than LOWER_FLT_LIMIT
|
||||
bool IsNodeScalarMinFLT(const AnfNodePtr &node) {
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
if (value_node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto value = value_node->value();
|
||||
if (value == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto scalar = value->cast<ScalarPtr>();
|
||||
if (scalar != nullptr) {
|
||||
if (scalar->isa<FloatImm>()) {
|
||||
return GetValue<float>(scalar) < LOWER_FLT_LIMIT;
|
||||
}
|
||||
}
|
||||
// Check for Tensor [] or Tensor [1]
|
||||
auto tensor_ptr = value->cast<tensor::TensorPtr>();
|
||||
if (tensor_ptr == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (tensor_ptr->DataSize() > 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
TypeId tensor_type = tensor_ptr->Dtype()->type_id();
|
||||
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
|
||||
float *data = reinterpret_cast<float *>(tensor_ptr->data_c());
|
||||
return data[0] < LOWER_FLT_LIMIT;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
bool IsNodeScalarMinFLT(const AnfNodePtr &node) { return IsNodeScalarTrueWith(node, LESS, LOWER_FLT_LIMIT); }
|
||||
|
||||
AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
|
||||
PatternNode x, y, z;
|
||||
|
@ -116,10 +104,15 @@ AnfNodePtr ValueBasedEliminate::operator()(const OptimizerPtr &, const AnfNodePt
|
|||
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSelect, PPrimitive(prim::kPrimGreater, x, zero_scalar_), y, z), y,
|
||||
IsCNodePositive(x.GetNode(node)));
|
||||
|
||||
// {prim::kPrimMaximum, X, LOWER_FLT_LIMIT}} -> X
|
||||
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMaximum, x, y), x, IsNodeScalarMinFLT(y.GetNode(node)));
|
||||
|
||||
// {prim::kPrimMinimum, X, UPPER_FLT_LIMIT}} -> X
|
||||
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMinimum, x, y), x, IsNodeScalarMaxFLT(y.GetNode(node)));
|
||||
|
||||
// {prim::kPrimMaximum, X, 0}} -> X when X is always greater or equal 0
|
||||
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimMaximum, x, zero_), x, IsCNodePositive(x.GetNode(node)));
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue