forked from mindspore-Ecosystem/mindspore
!28553 [Fallback] Deal with constants such as np.pi
Merge pull request !28553 from huangbingjian/attr_const
This commit is contained in:
commit
d7c41f3f5f
|
@ -896,7 +896,18 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
|
|||
MS_EXCEPTION_IF_NULL(block->func_graph());
|
||||
// Create the apply node
|
||||
auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
|
||||
UpdateInterpretForUserNode(attr_cnode, value_node);
|
||||
// The fallback feature is enabled in default.
|
||||
static const auto use_fallback = (support_fallback() != "0");
|
||||
if (use_fallback) {
|
||||
// Check whether it is constant, constant does not need interpret.
|
||||
auto value_str = py::cast<std::string>(ast()->GetAstNodeText(value_body));
|
||||
py::bool_ is_const_value =
|
||||
ast()->CallParserObjMethod(PYTHON_PARSE_CHECK_IS_CONSTANT_VALUE, value_str, common::SafeCStr(attr_str));
|
||||
auto is_constant = py::cast<bool>(is_const_value);
|
||||
if (!is_constant) {
|
||||
UpdateInterpretForUserNode(attr_cnode, value_node);
|
||||
}
|
||||
}
|
||||
return attr_cnode;
|
||||
}
|
||||
|
||||
|
@ -1994,6 +2005,12 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta
|
|||
}
|
||||
|
||||
void Parser::UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node) {
|
||||
// The fallback feature is enabled in default.
|
||||
static const auto use_fallback = (support_fallback() != "0");
|
||||
if (!use_fallback) {
|
||||
return;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(user_node);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// Do not handle user node with internal type such as Tensor.abs().
|
||||
|
|
|
@ -87,6 +87,7 @@ const char PYTHON_PARSE_GET_OPERATION_SYMBOL[] = "get_operation_symbol";
|
|||
const char PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL[] = "get_operation_namespace_symbol";
|
||||
const char PYTHON_PARSE_GET_BUILTIN_NAMESPACE_SYMBOL[] = "get_builtin_namespace_symbol";
|
||||
const char PYTHON_PARSE_GET_LOCATION[] = "get_location";
|
||||
const char PYTHON_PARSE_CHECK_IS_CONSTANT_VALUE[] = "is_constant_value";
|
||||
const char PYTHON_PARSE_EXPAND_EXPR_STATEMENT[] = "expand_expr_statement";
|
||||
const char PYTHON_PARSE_GENERATE_SCOPE[] = "generate_scope";
|
||||
const char PYTHON_PARSE_GET_SCOPE_NAME[] = "get_scope_name";
|
||||
|
|
|
@ -639,6 +639,15 @@ class Parser:
|
|||
logger.error("Fn type is invalid")
|
||||
return None, None
|
||||
|
||||
def is_constant_value(self, var, attr):
|
||||
if var in self.global_namespace:
|
||||
module = self.global_namespace[var]
|
||||
if hasattr(module, attr):
|
||||
value = getattr(module, attr)
|
||||
# Check if value is constant.
|
||||
return isinstance(value, (int, float, bool))
|
||||
return False
|
||||
|
||||
def is_unsupported_namespace(self, value):
|
||||
unsupported = isinstance(value, _builtin_function_or_method_type) and value not in convert_object_map
|
||||
logger.debug(f"'{value}' unsupported: {unsupported}.")
|
||||
|
|
|
@ -20,6 +20,7 @@ import mindspore.nn as nn
|
|||
from mindspore import Tensor, ms_function, context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.nn.probability import distribution
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.common._monad as monad
|
||||
|
||||
|
@ -334,3 +335,34 @@ def test_self_method_2():
|
|||
net = Network()
|
||||
out = net()
|
||||
print(out)
|
||||
|
||||
|
||||
def test_probability_cauchy():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: NumPy method is called in probability cauchy.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class CauchyProb(nn.Cell):
|
||||
def __init__(self, loc, scale, seed=10, dtype=mstype.float32, name='Cauchy'):
|
||||
super().__init__()
|
||||
self.b = distribution.Cauchy(loc, scale, seed, dtype, name)
|
||||
|
||||
def construct(self, value, loc=None, scale=None):
|
||||
out1 = self.b.prob(value, loc, scale)
|
||||
out2 = self.b.log_prob(value, loc, scale)
|
||||
out3 = self.b.cdf(value, loc, scale)
|
||||
out4 = self.b.log_cdf(value, loc, scale)
|
||||
out5 = self.b.survival_function(value, loc, scale)
|
||||
out6 = self.b.log_survival(value, loc, scale)
|
||||
return out1, out2, out3, out4, out5, out6
|
||||
|
||||
|
||||
loc = np.random.randn(1024, 512, 7, 7).astype(np.float32)
|
||||
scale = np.random.uniform(0.0001, 100, size=(1024, 512, 7, 7)).astype(np.float32)
|
||||
loc_a = np.random.randn(1024, 512, 7, 7).astype(np.float32)
|
||||
scale_a = np.random.uniform(0.0001, 100, size=(1024, 512, 7, 7)).astype(np.float32)
|
||||
value = np.random.randn(1024, 512, 7, 7).astype(np.float32)
|
||||
|
||||
net = CauchyProb(loc, scale)
|
||||
net(Tensor(value), Tensor(loc_a), Tensor(scale_a))
|
||||
|
|
Loading…
Reference in New Issue