forked from mindspore-Ecosystem/mindspore
* add isconstant primitive
* add infer_value for common math ops
* convert constant bool tensor to bool value
* do not infer value when encountering 0 as divisor
* for while condition, do not unroll the loop if the condition is a tensor
parent fac6c56db5
commit beacc26077
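Taken together, the changes below enable two behaviors, sketched here in a minimal MindSpore-flavored example (a hedged sketch, assuming a build that contains this commit; it mirrors the new tests at the bottom of the diff):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_context(mode=context.GRAPH_MODE)

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        # constant bool tensor: now folded to a Python bool at compile time
        self.flag = Tensor(np.array(True, np.bool_))
        # tensor loop bound: the while loop below is no longer unrolled
        self.count = Tensor(np.array([3], np.int32))

    def construct(self, x):
        if self.flag:              # constant tensor condition: branch folded
            x = x + 1
        i = 0
        while i < self.count:      # tensor condition: kept as a real loop
            x = x * 2
            i = i + 1
        return x

net = Net()
print(net(Tensor(np.ones([2], np.int32))))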
@@ -126,7 +126,7 @@ convert_object_map = {
    T.make_list: F.make_list,
    T.make_slice: F.make_slice,
    T.range: F.make_range,
    T.while_cond: M.while_cond,
    # lib function
    math.floor: NO_IMPLEMENT,
    math.trunc: NO_IMPLEMENT,
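The entry above is pure registration: the parser's T.while_cond symbol resolves through this map to the Python implementation M.while_cond defined in standard_method.py in the next hunks. A pure-Python sketch of that resolution step (the dictionary and names here are stand-ins, not the real T/F/M modules):

def while_cond_impl(x):
    """Stand-in for M.while_cond; the real one appears in the next hunk."""
    return x

# the convert map turns a parse-time symbol into its graph implementation
convert_object_map = {"while_cond": while_cond_impl}
assert convert_object_map["while_cond"](True) is True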
@@ -16,8 +16,10 @@
# ============================================================================
"""standard_method"""
from dataclasses import dataclass
from mindspore.common import dtype as mstype
from ...ops import functional as F
from ...ops import operations as P
from ...ops.primitive import constexpr
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
    zeros_like, ones_like
from ...ops.composite.base import _append
@@ -102,11 +104,44 @@ def bool_(x):
    return x.__bool__()


def tensor_bool(x):
    """return immediate x, x is a tensor of bool value"""
def while_cond(x):
    """For while condition, if the condition is a tensor, the loop will not be unrolled"""
    if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
        is_cond = check_is_tensor_bool_cond(F.shape(x))
        if is_cond:
            return F.cast(x, mstype.bool_)
    return x


@constexpr
def check_is_tensor_bool_cond(shp):
    """check if tensor is a bool condition"""
    if shp in ((), (1,)):
        return True
    raise ValueError("tensor as bool condition, its shape should be () or (1,), but got ", shp)


@constexpr
def const_tensor_to_bool(x):
    """convert bool tensor to bool condition"""
    if x is None:
        raise ValueError("Only constant tensor bool can be converted to bool")
    x = x.asnumpy()
    if x.shape not in ((), (1,)):
        raise ValueError("Tensor to bool should input shape () or (1), but got ", x.shape)
    if x.shape == ():
        value = bool(x)
    else:
        value = bool(x[0])
    return value


def tensor_bool(x):
    """tensor as condition, if is constant, return immediate bool value"""
    is_cond = check_is_tensor_bool_cond(F.shape(x))
    if is_cond and F.isconstant(x):
        return const_tensor_to_bool(x)
    return F.cast(x, mstype.bool_)


def and_(x, y):
    """Implementation of `and` (`&`)."""
    return x.__and__(y)
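const_tensor_to_bool above only accepts shapes () and (1,); anything else raises. A standalone numpy re-creation of that contract (the demo function is hypothetical; the real one runs under @constexpr during graph compilation):

import numpy as np

def const_tensor_to_bool_demo(x):
    """Mirror of const_tensor_to_bool's shape contract, on plain ndarrays."""
    if x.shape not in ((), (1,)):
        raise ValueError("Tensor to bool should input shape () or (1), but got ", x.shape)
    return bool(x) if x.shape == () else bool(x[0])

assert const_tensor_to_bool_demo(np.array(True)) is True      # shape ()
assert const_tensor_to_bool_demo(np.array([False])) is False  # shape (1,)
# shape (2,) raises ValueError, which is what test_tensor_cond_exception checks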
@@ -91,3 +91,7 @@ def to_array(x):  # pragma: no cover
def not_contains(x):  # pragma: no cover
    """Not in function."""
    raise RuntimeError('This operation is not meant to be called directly.')


def while_cond(x):  # pragma: no cover
    """While condition function."""
    raise RuntimeError('This operation is not meant to be called directly.')
@@ -281,6 +281,16 @@ class TraceForceBool : public TraceInfo {
  TraceInfoPtr clone() override { return std::make_shared<TraceForceBool>(*shared_from_base<TraceForceBool>()); }
};

class TraceForceWhileCond : public TraceInfo {
 public:
  explicit TraceForceWhileCond(const DebugInfoPtr &info) : TraceInfo(info, "force_while_cond", "") {}
  MS_DECLARE_PARENT(TraceForceWhileCond, TraceInfo);
  ~TraceForceWhileCond() override = default;
  TraceInfoPtr clone() override {
    return std::make_shared<TraceForceWhileCond>(*shared_from_base<TraceForceWhileCond>());
  }
};

class TraceExpandJ : public TraceInfo {
 public:
  explicit TraceExpandJ(const DebugInfoPtr &info) : TraceInfo(info, "expand_j", "") {}
@@ -243,6 +243,7 @@ const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant");

// Comm ops
const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
@@ -252,6 +252,7 @@ extern const PrimitivePtr kPrimIsNot;
extern const PrimitivePtr kPrimInDict;
extern const PrimitivePtr kPrimNotInDict;
extern const PrimitivePtr kPrimMixedPrecisionCast;
extern const PrimitivePtr kPrimIsConsant;

// Comm ops
extern const PrimitivePtr kPrimMirror;
@@ -110,7 +110,8 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
  ValuePtr v = cond->GetValueTrack();
  MS_EXCEPTION_IF_NULL(v);
  if (v->isa<AnyValue>()) {
  // for tensor as condition, keeps both true and false branch.
  if (v->isa<AnyValue>() || cond->isa<AbstractTensor>()) {
    MS_EXCEPTION_IF_NULL(tb);
    return tb->Join(fb);
  }
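The change above keeps both branches whenever the condition is a tensor, even when its value is known. A toy pure-Python model of that rule (join and infer_switch are invented names; the real Join operates on abstract values during inference):

def join(a, b):
    """Toy join: both branches must agree on a common abstract value."""
    if a != b:
        raise TypeError("cannot join branch types: %r vs %r" % (a, b))
    return a

def infer_switch(cond_value, cond_is_tensor, tb, fb):
    # cond_value is None when unknown (AnyValue in the C++ above)
    if cond_value is None or cond_is_tensor:
        return join(tb, fb)           # keep both branches
    return tb if cond_value else fb   # constant scalar cond: fold the branch

assert infer_switch(None, False, "f32", "f32") == "f32"  # unknown: joined
assert infer_switch(True, True, "f32", "f32") == "f32"   # tensor: still joined
assert infer_switch(True, False, "f32", "i32") == "f32"  # constant: folded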
@@ -228,5 +229,16 @@ AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr
  // Inputs: x, t
  return std::make_shared<AbstractScalar>(!IsInDict(primitive, args_spec_list));
}

AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list) {
  // statement: isconstant(x)
  // Inputs: x
  if (args_spec_list.size() != 1) {
    MS_LOG(EXCEPTION) << "IsConstant requires args input size = 1";
  }
  ValuePtr v = args_spec_list[0]->BuildValue();
  return std::make_shared<AbstractScalar>(!v->isa<AnyValue>());
}
}  // namespace abstract
}  // namespace mindspore
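InferImplIsConstant folds is_constant(x) to a compile-time bool: true exactly when inference already knows x's value. The same rule in a few lines of runnable Python (ANY and infer_is_constant are invented stand-ins for AnyValue and the C++ routine above):

ANY = object()  # stands in for AnyValue: "value not known until runtime"

def infer_is_constant(tracked_value):
    """True iff the value is already known during inference."""
    return tracked_value is not ANY

assert infer_is_constant(3) is True     # literal: known at compile time
assert infer_is_constant(ANY) is False  # graph input: unknown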
@@ -265,6 +265,13 @@ CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) {
  return op_apply_node;
}

CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
  TraceManager::DebugTrace(std::make_shared<TraceForceWhileCond>(cond->debug_info()));
  CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation("while_cond"), cond});
  TraceManager::EndTrace();
  return op_apply_node;
}

// Perform a jump from this block to target block
void FunctionBlock::Jump(const FunctionBlockPtr &target_block, AnfNodePtr node) {
  if (func_graph()->get_return() != nullptr) {
@@ -55,6 +55,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
  // A block is matured if all its predecessors are generated
  void Mature();
  CNodePtr ForceToBoolNode(const AnfNodePtr &cond);
  CNodePtr ForceToWhileCond(const AnfNodePtr &cond);
  void Jump(const FunctionBlockPtr &block, AnfNodePtr node);
  AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi);
  void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock);
@@ -967,6 +967,7 @@ FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::obj
  py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  AnfNodePtr condition_node = ParseExprNode(header_block, test_node);
  condition_node = header_block->ForceToWhileCond(condition_node);
  body_block->Mature();
  header_block->ConditionalJump(condition_node, body_block, after_block);
@@ -55,6 +55,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    {prim::kPrimIsNot, {InferImplIsNot, true}},
    {prim::kPrimInDict, {InferImplInDict, true}},
    {prim::kPrimNotInDict, {InferImplNotInDict, true}},
    {prim::kPrimIsConsant, {InferImplIsConstant, true}},
    // Maths
    {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
    {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
@@ -200,6 +200,8 @@ AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
                                const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
                                   const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &,
                                    const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -26,6 +26,7 @@ typeof = Primitive('typeof')
hastype = Primitive('hastype')
cast = P.Cast()
dtype = P.DType()
isconstant = Primitive('is_constant')


issubclass_ = P.IsSubClass()
@@ -2294,7 +2294,7 @@ class Abs(PrimitiveWithInfer):
    def infer_value(self, x):
        if x is not None:
            x = x.asnumpy()
            out = np.abs(x, dtype=x.dtype)
            out = np.array(np.abs(x, dtype=x.dtype))
            return Tensor(out)
        return None
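A plausible motivation for the one-line change above: numpy ufuncs return a scalar, not an ndarray, when every input is 0-d, so a scalar tensor reaching infer_value would hand Tensor() a numpy scalar. Wrapping in np.array restores an ndarray (runnable demo; assumes only standard numpy behavior):

import numpy as np

x = np.array(-3.0)                        # 0-d array, as a scalar Tensor yields
out = np.abs(x, dtype=x.dtype)
print(type(out))                          # numpy.float64 scalar, not ndarray
out = np.array(np.abs(x, dtype=x.dtype))  # the committed fix
print(type(out), out.shape)               # <class 'numpy.ndarray'> ()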
@@ -147,8 +147,8 @@ TEST_F(TestOptLib, test_inline_new_closure) {
TEST_F(TestOptLib, test_inline_while) {
  FuncGraphPtr before = getPyFun.CallAndParseRet("test_inline_while", "before");
  auto patterns = std::vector<SubstitutionPtr>({irpass.inline_});
  FuncGraphPtr after_ = RunSubs(before, patterns);
  ASSERT_TRUE(CheckOpt(before, before, patterns));
  FuncGraphPtr after = RunSubs(before, patterns);
  ASSERT_TRUE(CheckOpt(before, after, patterns, true));
}

TEST_F(TestOptLib, test_arithmetic) {
@@ -520,3 +520,83 @@ def test_while_in_while():
        out = out + 3
        return out
    while_in_while(c1, c2, c3, c4)


def test_tensor_cond():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.t = Tensor(np.array(0, np.bool))
            self.t1 = Tensor(np.array([True], np.bool))

        def construct(self, x, y):
            t = 0
            if self.t:
                t = t - x * y
            else:
                t = t - x / y
            if self.t1:
                t = t + x / y
            else:
                t = t + x * y
            return t

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.ones([6, 8, 10], np.int32))
    net = Net()
    out = net(x, y)


def test_tensor_cond_exception():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.t = Tensor(np.array([True, False], np.bool))

        def construct(self, x, y):
            t = 0
            if self.t:
                t = t - x * y
            else:
                t = t - x / y
            return t

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.ones([6, 8, 10], np.int32))
    net = Net()
    with pytest.raises(ValueError):
        out = net(x, y)


def test_while_scalar():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.x = 10

        def construct(self, x, y):
            i = 0
            t = 0
            while (i < 10):
                t = t + x + y
                i = i + 1
            return t

    net = Net()
    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.ones([6, 8, 10], np.int32))
    out = net(x, y)


def test_while_tensor():
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.t = Tensor(np.ones([6, 8, 10], np.int32))
            self.count = Tensor(np.array([10], np.int32))

        def construct(self, x, y):
            i = 0
            t = self.t
            while (i < self.count):
                t = t + x + y
                i = i + 1
            return t

    net = Net()
    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.ones([6, 8, 10], np.int32))
    out = net(x, y)
@@ -31,7 +31,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from ....mindspore_test_framework.pipeline.forward.verify_exception \
    import pipeline_for_verify_exception_for_case_by_case_config

context.set_context(mode=context.GRAPH_MODE)

# pylint: disable=W0613
# pylint: disable=W0231
@@ -30,6 +30,7 @@ from ....mindspore_test_framework.utils.check_gradient import (
    ms_function, check_jacobian, Tensor, NNGradChecker,
    OperationGradChecker, check_gradient, ScalarGradChecker)

context.set_context(mode=context.PYNATIVE_MODE)

def setup_module(module):
    context.set_context(mode=context.PYNATIVE_MODE)
@@ -257,8 +258,8 @@ def if_tensor(a, b):


def test_if_tensor():
    res = if_tensor(Tensor(np.ones([64, 10]).astype(np.int32)), Tensor(np.ones([64, 10]).astype(np.int32)))
    assert res == Tensor(np.ones([64, 10]).astype(np.int32) * 4)
    res = if_tensor(Tensor(np.ones([1]).astype(np.int32)), Tensor(np.ones([1]).astype(np.int32)))
    assert res == Tensor(np.ones([1]).astype(np.int32) * 4)


@ms_function
@@ -399,7 +400,7 @@ def if_while(a, b, x, z):
def test_if_while():
    x = Tensor(np.random.randn(1, 16, 12, 12).astype(np.float32))
    z = Tensor(np.random.randn(1, 16, 16, 16).astype(np.float32))
    res = if_while(Tensor(np.ones([64, 10]).astype(np.float32)), Tensor(np.ones([64, 10]).astype(np.float32)), x, z)
    res = if_while(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)), x, z)
    assert res == Tensor(np.ones([64, 10]).astype(np.float32) * 4.0)