* add isconstant primitive

* add infer_value for common math ops
* convert constant bool tensor to bool value

* do not infer value when encounter 0 as division

for while condition, do not unrool if condition is a tensor
This commit is contained in:
huangdongrun 2020-05-30 10:48:06 +08:00
parent fac6c56db5
commit beacc26077
18 changed files with 168 additions and 11 deletions

View File

@ -126,7 +126,7 @@ convert_object_map = {
T.make_list: F.make_list,
T.make_slice: F.make_slice,
T.range: F.make_range,
T.while_cond: M.while_cond,
# lib function
math.floor: NO_IMPLEMENT,
math.trunc: NO_IMPLEMENT,

View File

@ -16,8 +16,10 @@
# ============================================================================
"""standard_method"""
from dataclasses import dataclass
from mindspore.common import dtype as mstype
from ...ops import functional as F
from ...ops import operations as P
from ...ops.primitive import constexpr
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
zeros_like, ones_like
from ...ops.composite.base import _append
@ -102,11 +104,44 @@ def bool_(x):
return x.__bool__()
def tensor_bool(x):
"""return immedate x, x is a tensor of bool value"""
def while_cond(x):
"""For while condtion, if the condition is a tensor, the loop will not be unrolled"""
if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
is_cond = check_is_tensor_bool_cond(F.shape(x))
if is_cond:
return F.cast(x, mstype.bool_)
return x
@constexpr
def check_is_tensor_bool_cond(shp):
"""check if tensor is a bool condition"""
if shp in ((), (1,)):
return True
raise ValueError("tensor as bool condition, its shape should be () or (1,), but got ", shp)
@constexpr
def const_tensor_to_bool(x):
"""convert bool tensor to bool condition"""
if x is None:
raise ValueError("Only constant tensor bool can be converted to bool")
x = x.asnumpy()
if x.shape not in ((), (1,)):
raise ValueError("Tensor to bool should input shape () or (1), but got ", x.shape)
if x.shape == ():
value = bool(x)
else:
value = bool(x[0])
return value
def tensor_bool(x):
"""tensor as conditon, if is constant, return immediate bool value"""
is_cond = check_is_tensor_bool_cond(F.shape(x))
if is_cond and F.isconstant(x):
return const_tensor_to_bool(x)
return F.cast(x, mstype.bool_)
def and_(x, y):
"""Implementation of `and` (`&`)."""
return x.__and__(y)

View File

@ -91,3 +91,7 @@ def to_array(x): # pragma: no cover
def not_contains(x): # pragma: no cover
"""Not in function."""
raise RuntimeError('This operation is not meant to be called directly.')
def while_cond(x): # pragma: no cover
"""Not in function."""
raise RuntimeError('This operation is not meant to be called directly.')

View File

@ -281,6 +281,16 @@ class TraceForceBool : public TraceInfo {
TraceInfoPtr clone() override { return std::make_shared<TraceForceBool>(*shared_from_base<TraceForceBool>()); }
};
class TraceForceWhileCond : public TraceInfo {
public:
explicit TraceForceWhileCond(const DebugInfoPtr &info) : TraceInfo(info, "force_while_cond", "") {}
MS_DECLARE_PARENT(TraceForceWhileCond, TraceInfo);
~TraceForceWhileCond() override = default;
TraceInfoPtr clone() override {
return std::make_shared<TraceForceWhileCond>(*shared_from_base<TraceForceWhileCond>());
}
};
class TraceExpandJ : public TraceInfo {
public:
explicit TraceExpandJ(const DebugInfoPtr &info) : TraceInfo(info, "expand_j", "") {}

View File

@ -243,6 +243,7 @@ const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant");
// Comm ops
const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");

View File

@ -252,6 +252,7 @@ extern const PrimitivePtr kPrimIsNot;
extern const PrimitivePtr kPrimInDict;
extern const PrimitivePtr kPrimNotInDict;
extern const PrimitivePtr kPrimMixedPrecisionCast;
extern const PrimitivePtr kPrimIsConsant;
// Comm ops
extern const PrimitivePtr kPrimMirror;

View File

@ -110,7 +110,8 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
ValuePtr v = cond->GetValueTrack();
MS_EXCEPTION_IF_NULL(v);
if (v->isa<AnyValue>()) {
// for tensor as condition, keeps both true and false branch.
if (v->isa<AnyValue>() || cond->isa<AbstractTensor>()) {
MS_EXCEPTION_IF_NULL(tb);
return tb->Join(fb);
}
@ -228,5 +229,16 @@ AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr
// Inputs: x, t
return std::make_shared<AbstractScalar>(!IsInDict(primitive, args_spec_list));
}
AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// statement: isconstant(x)
// Inputs: x
if (args_spec_list.size() != 1) {
MS_LOG(EXCEPTION) << "IsConstant requires args input size = 1";
}
ValuePtr v = args_spec_list[0]->BuildValue();
return std::make_shared<AbstractScalar>(!v->isa<AnyValue>());
}
} // namespace abstract
} // namespace mindspore

View File

@ -265,6 +265,13 @@ CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) {
return op_apply_node;
}
CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
TraceManager::DebugTrace(std::make_shared<TraceForceWhileCond>(cond->debug_info()));
CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation("while_cond"), cond});
TraceManager::EndTrace();
return op_apply_node;
}
// Perform a jump from this block to target block
void FunctionBlock::Jump(const FunctionBlockPtr &target_block, AnfNodePtr node) {
if (func_graph()->get_return() != nullptr) {

View File

@ -55,6 +55,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
// A block is matured if all its predecessors is generated
void Mature();
CNodePtr ForceToBoolNode(const AnfNodePtr &cond);
CNodePtr ForceToWhileCond(const AnfNodePtr &cond);
void Jump(const FunctionBlockPtr &block, AnfNodePtr node);
AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi);
void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock);

View File

@ -967,6 +967,7 @@ FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::obj
py::object test_node = python_adapter::GetPyObjAttr(node, "test");
AnfNodePtr condition_node = ParseExprNode(header_block, test_node);
condition_node = header_block->ForceToWhileCond(condition_node);
body_block->Mature();
header_block->ConditionalJump(condition_node, body_block, after_block);

View File

@ -55,6 +55,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimIsNot, {InferImplIsNot, true}},
{prim::kPrimInDict, {InferImplInDict, true}},
{prim::kPrimNotInDict, {InferImplNotInDict, true}},
{prim::kPrimIsConsant, {InferImplIsConstant, true}},
// Maths
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},

View File

@ -200,6 +200,8 @@ AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -26,6 +26,7 @@ typeof = Primitive('typeof')
hastype = Primitive('hastype')
cast = P.Cast()
dtype = P.DType()
isconstant = Primitive('is_constant')
issubclass_ = P.IsSubClass()

View File

@ -2294,7 +2294,7 @@ class Abs(PrimitiveWithInfer):
def infer_value(self, x):
if x is not None:
x = x.asnumpy()
out = np.abs(x, dtype=x.dtype)
out = np.array(np.abs(x, dtype=x.dtype))
return Tensor(out)
return None

View File

@ -147,8 +147,8 @@ TEST_F(TestOptLib, test_inline_new_closure) {
TEST_F(TestOptLib, test_inline_while) {
FuncGraphPtr before = getPyFun.CallAndParseRet("test_inline_while", "before");
auto patterns = std::vector<SubstitutionPtr>({irpass.inline_});
FuncGraphPtr after_ = RunSubs(before, patterns);
ASSERT_TRUE(CheckOpt(before, before, patterns));
FuncGraphPtr after = RunSubs(before, patterns);
ASSERT_TRUE(CheckOpt(before, after, patterns, true));
}
TEST_F(TestOptLib, test_arithmetic) {

View File

@ -520,3 +520,83 @@ def test_while_in_while():
out = out + 3
return out
while_in_while(c1, c2, c3, c4)
def test_tensor_cond():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.t = Tensor(np.array(0, np.bool))
self.t1 = Tensor(np.array([True], np.bool))
def construct(self, x, y):
t = 0
if self.t:
t = t - x * y
else:
t = t - x / y
if self.t1:
t = t + x / y
else:
t = t + x * y
return t
x = Tensor(np.ones([6, 8, 10], np.int32))
y = Tensor(np.ones([6, 8, 10], np.int32))
net = Net()
out = net(x, y)
def test_tensor_cond_exception():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.t = Tensor(np.array([True, False], np.bool))
def construct(self, x, y):
t = 0
if self.t:
t = t - x * y
else:
t = t - x / y
return t
x = Tensor(np.ones([6, 8, 10], np.int32))
y = Tensor(np.ones([6, 8, 10], np.int32))
net = Net()
with pytest.raises(ValueError):
out = net(x, y)
def test_while_scalar():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.x = 10
def construct(self, x, y):
i = 0
t = 0
while (i < 10):
t = t + x + y
i = i + 1
return t
net = Net()
x = Tensor(np.ones([6, 8, 10], np.int32))
y = Tensor(np.ones([6, 8, 10], np.int32))
out = net(x, y)
def test_while_tensor():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.t = Tensor(np.ones([6, 8, 10], np.int32))
self.count = Tensor(np.array([10], np.int32))
def construct(self, x, y):
i = 0
t = self.t
while (i < self.count):
t = t + x + y
i = i + 1
return t
net = Net()
x = Tensor(np.ones([6, 8, 10], np.int32))
y = Tensor(np.ones([6, 8, 10], np.int32))
out = net(x, y)

View File

@ -31,7 +31,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from ....mindspore_test_framework.pipeline.forward.verify_exception \
import pipeline_for_verify_exception_for_case_by_case_config
context.set_context(mode=context.GRAPH_MODE)
# pylint: disable=W0613
# pylint: disable=W0231

View File

@ -30,6 +30,7 @@ from ....mindspore_test_framework.utils.check_gradient import (
ms_function, check_jacobian, Tensor, NNGradChecker,
OperationGradChecker, check_gradient, ScalarGradChecker)
context.set_context(mode=context.PYNATIVE_MODE)
def setup_module(module):
context.set_context(mode=context.PYNATIVE_MODE)
@ -257,8 +258,8 @@ def if_tensor(a, b):
def test_if_tensor():
res = if_tensor(Tensor(np.ones([64, 10]).astype(np.int32)), Tensor(np.ones([64, 10]).astype(np.int32)))
assert res == Tensor(np.ones([64, 10]).astype(np.int32) * 4)
res = if_tensor(Tensor(np.ones([1]).astype(np.int32)), Tensor(np.ones([1]).astype(np.int32)))
assert res == Tensor(np.ones([1]).astype(np.int32) * 4)
@ms_function
@ -399,7 +400,7 @@ def if_while(a, b, x, z):
def test_if_while():
x = Tensor(np.random.randn(1, 16, 12, 12).astype(np.float32))
z = Tensor(np.random.randn(1, 16, 16, 16).astype(np.float32))
res = if_while(Tensor(np.ones([64, 10]).astype(np.float32)), Tensor(np.ones([64, 10]).astype(np.float32)), x, z)
res = if_while(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)), x, z)
assert res == Tensor(np.ones([64, 10]).astype(np.float32) * 4.0)