!28553 [Fallback] Deal with constants such as np.pi

Merge pull request !28553 from huangbingjian/attr_const
This commit is contained in:
i-robot 2022-01-06 07:13:24 +00:00 committed by Gitee
commit d7c41f3f5f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 60 additions and 1 deletions

View File

@ -896,7 +896,18 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
MS_EXCEPTION_IF_NULL(block->func_graph()); MS_EXCEPTION_IF_NULL(block->func_graph());
// Create the apply node // Create the apply node
auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node}); auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
// The fallback feature is enabled in default.
static const auto use_fallback = (support_fallback() != "0");
if (use_fallback) {
// Check whether it is constant, constant does not need interpret.
auto value_str = py::cast<std::string>(ast()->GetAstNodeText(value_body));
py::bool_ is_const_value =
ast()->CallParserObjMethod(PYTHON_PARSE_CHECK_IS_CONSTANT_VALUE, value_str, common::SafeCStr(attr_str));
auto is_constant = py::cast<bool>(is_const_value);
if (!is_constant) {
UpdateInterpretForUserNode(attr_cnode, value_node); UpdateInterpretForUserNode(attr_cnode, value_node);
}
}
return attr_cnode; return attr_cnode;
} }
@ -1994,6 +2005,12 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta
} }
void Parser::UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node) { void Parser::UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node) {
// The fallback feature is enabled in default.
static const auto use_fallback = (support_fallback() != "0");
if (!use_fallback) {
return;
}
MS_EXCEPTION_IF_NULL(user_node); MS_EXCEPTION_IF_NULL(user_node);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
// Do not handle user node with internal type such as Tensor.abs(). // Do not handle user node with internal type such as Tensor.abs().

View File

@ -87,6 +87,7 @@ const char PYTHON_PARSE_GET_OPERATION_SYMBOL[] = "get_operation_symbol";
const char PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL[] = "get_operation_namespace_symbol"; const char PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL[] = "get_operation_namespace_symbol";
const char PYTHON_PARSE_GET_BUILTIN_NAMESPACE_SYMBOL[] = "get_builtin_namespace_symbol"; const char PYTHON_PARSE_GET_BUILTIN_NAMESPACE_SYMBOL[] = "get_builtin_namespace_symbol";
const char PYTHON_PARSE_GET_LOCATION[] = "get_location"; const char PYTHON_PARSE_GET_LOCATION[] = "get_location";
const char PYTHON_PARSE_CHECK_IS_CONSTANT_VALUE[] = "is_constant_value";
const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement"; const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement";
const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope"; const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope";
const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name"; const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name";

View File

@ -639,6 +639,15 @@ class Parser:
logger.error("Fn type is invalid") logger.error("Fn type is invalid")
return None, None return None, None
def is_constant_value(self, var, attr):
if var in self.global_namespace:
module = self.global_namespace[var]
if hasattr(module, attr):
value = getattr(module, attr)
# Check if value is constant.
return isinstance(value, (int, float, bool))
return False
def is_unsupported_namespace(self, value): def is_unsupported_namespace(self, value):
unsupported = isinstance(value, _builtin_function_or_method_type) and value not in convert_object_map unsupported = isinstance(value, _builtin_function_or_method_type) and value not in convert_object_map
logger.debug(f"'{value}' unsupported: {unsupported}.") logger.debug(f"'{value}' unsupported: {unsupported}.")

View File

@ -20,6 +20,7 @@ import mindspore.nn as nn
from mindspore import Tensor, ms_function, context from mindspore import Tensor, ms_function, context
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.nn.probability import distribution
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore.common._monad as monad import mindspore.common._monad as monad
@ -334,3 +335,34 @@ def test_self_method_2():
net = Network() net = Network()
out = net() out = net()
print(out) print(out)
def test_probability_cauchy():
"""
Feature: JIT Fallback
Description: NumPy method is called in probability cauchy.
Expectation: No exception.
"""
class CauchyProb(nn.Cell):
def __init__(self, loc, scale, seed=10, dtype=mstype.float32, name='Cauchy'):
super().__init__()
self.b = distribution.Cauchy(loc, scale, seed, dtype, name)
def construct(self, value, loc=None, scale=None):
out1 = self.b.prob(value, loc, scale)
out2 = self.b.log_prob(value, loc, scale)
out3 = self.b.cdf(value, loc, scale)
out4 = self.b.log_cdf(value, loc, scale)
out5 = self.b.survival_function(value, loc, scale)
out6 = self.b.log_survival(value, loc, scale)
return out1, out2, out3, out4, out5, out6
loc = np.random.randn(1024, 512, 7, 7).astype(np.float32)
scale = np.random.uniform(0.0001, 100, size=(1024, 512, 7, 7)).astype(np.float32)
loc_a = np.random.randn(1024, 512, 7, 7).astype(np.float32)
scale_a = np.random.uniform(0.0001, 100, size=(1024, 512, 7, 7)).astype(np.float32)
value = np.random.randn(1024, 512, 7, 7).astype(np.float32)
net = CauchyProb(loc, scale)
net(Tensor(value), Tensor(loc_a), Tensor(scale_a))