support isinstance(x, CSRTensor) in graph mode

support F.dtype(csr_tensor)
This commit is contained in:
wangrao124 2022-01-18 14:18:07 +08:00
parent cedca0b0c0
commit b24f1f833a
11 changed files with 147 additions and 10 deletions

View File

@ -38,9 +38,8 @@ BuiltInTypeMap &GetMethodMap() {
{"__bool__", std::string("none_bool")} // C.none_bool
}},
{kObjectTypeFunction,
{
{"__bool__", std::string("func_bool")} // C.str_bool
}},
{{"__bool__", std::string("func_bool")}, // C.str_bool
{"__is_csr_func__", prim::kPrimIsCSRFunc}}},
{kNumberTypeBool,
{
{"__and__", prim::kPrimBoolAnd}, // P.bool_and

View File

@ -163,6 +163,8 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -197,5 +197,21 @@ AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePt
ValuePtr v = args_spec_list[0]->BuildValue();
return std::make_shared<AbstractScalar>(!v->isa<AnyValue>());
}
AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Statement: x.__is_csr_func__()
// Inputs: x
auto func = CheckArg<AbstractFunction>(primitive->name(), args_spec_list, 0);
MS_EXCEPTION_IF_NULL(func);
auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
MS_EXCEPTION_IF_NULL(prim_func);
PrimitivePtr prim = prim_func->prim();
std::string name = prim->name();
if (name == "S-Prim-MakeCSRTensor") {
return std::make_shared<AbstractScalar>(1);
}
return std::make_shared<AbstractScalar>(0);
}
} // namespace abstract
} // namespace mindspore

View File

@ -126,6 +126,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimInDict, R{InferImplInDict, nullptr, true}},
{prim::kPrimNotInDict, R{InferImplNotInDict, nullptr, true}},
{prim::kPrimIsConsant, R{InferImplIsConstant, nullptr, true}},
{prim::kPrimIsCSRFunc, R{InferImplIsCSRFunc, nullptr, true}},
// Maths
{prim::kPrimMatMul, R{InferImplMatMul, nullptr, true}},
{prim::kPrimBatchMatMul, R{InferImplBatchMatMul, nullptr, true}},

View File

@ -140,6 +140,7 @@ constexpr auto kCSRTensorGetValues = "CSRTensorGetValues";
constexpr auto kCSRTensorGetIndptr = "CSRTensorGetIndptr";
constexpr auto kCSRTensorGetIndices = "CSRTensorGetIndices";
constexpr auto kCSRTensorGetDenseShape = "CSRTensorGetDenseShape";
constexpr auto kIsCSRFunc = "IsCSRFunc";
// Sparse ops
constexpr auto kSparseTensorDenseMatmul = "SparseTensorDenseMatmul";
@ -583,6 +584,7 @@ MS_CORE_API inline const PrimitivePtr kPrimCSRTensorGetIndptr = std::make_shared
MS_CORE_API inline const PrimitivePtr kPrimCSRTensorGetIndices = std::make_shared<Primitive>(kCSRTensorGetIndices);
MS_CORE_API inline const PrimitivePtr kPrimCSRTensorGetDenseShape =
std::make_shared<Primitive>(kCSRTensorGetDenseShape);
MS_CORE_API inline const PrimitivePtr kPrimIsCSRFunc = std::make_shared<Primitive>(kIsCSRFunc);
// Sparse ops
MS_CORE_API inline const PrimitivePtr kPrimSparseTensorDenseMatmul =

View File

@ -32,10 +32,18 @@ ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<Abstra
auto op_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("dtype infer", int64_t(input_args.size()), kEqual, 1, op_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
const std::set<TypePtr> valid_types = {kTensorType};
auto type =
CheckAndConvertUtils::CheckTensorTypeValid("infer type", input_args[0]->BuildType(), valid_types, op_name);
return type;
auto type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(type);
if (type->isa<TensorType>()) {
const std::set<TypePtr> valid_types = {kTensorType};
return CheckAndConvertUtils::CheckTensorTypeValid("infer type", type, valid_types, op_name);
} else if (input_args[0]->BuildType()->isa<CSRTensorType>()) {
const std::set<TypePtr> valid_types = {kCSRTensorType};
return CheckAndConvertUtils::CheckCSRTensorTypeValid("infer type", type, valid_types, op_name);
}
MS_EXCEPTION(TypeError) << "For Primitive[" << op_name << "], the input argument[infer type]"
<< "must be a Tensor or CSRTensor but got " << type->ToString() << ".";
return nullptr;
}
AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -552,6 +552,20 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name,
return CheckTensorSubClass(type_name, type, check_list, prim_name);
}
TypePtr CheckAndConvertUtils::CheckCSRTensorTypeValid(const std::string &type_name, const TypePtr &type,
const std::set<TypePtr> &check_list,
const std::string &prim_name) {
MS_EXCEPTION_IF_NULL(type);
if (!type->isa<CSRTensorType>()) {
MS_EXCEPTION(TypeError) << "For Primitive[" << prim_name << "], the input argument[" << type_name
<< "] must be a CSRTensor but got " << type->ToString() << ".";
}
auto csr_tensor_type = type->cast<CSRTensorTypePtr>();
auto element = csr_tensor_type->element();
MS_EXCEPTION_IF_NULL(element);
return element;
}
ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_name, const ValuePtr &value,
const std::string &prim_name) {
if (value == nullptr) {

View File

@ -239,6 +239,8 @@ class CheckAndConvertUtils {
const std::string &prim_name);
static TypePtr CheckTensorTypeValid(const std::string &type_name, const TypePtr &type,
const std::set<TypePtr> &check_list, const std::string &prim_name);
static TypePtr CheckCSRTensorTypeValid(const std::string &type_name, const TypePtr &type,
const std::set<TypePtr> &check_list, const std::string &prim_name);
static TypePtr CheckSubClass(const std::string &type_name, const TypePtr &type,
const std::set<TypePtr> &template_types, const std::string &prim_name);
static TypePtr CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,

View File

@ -1536,10 +1536,32 @@ def view(x, *shape):
return F.reshape(x, shape)
@constexpr
def check_is_tuple(x):
"""check whether x is tuple."""
return isinstance(x, mstype.Tuple)
@constexpr
def check_is_func(x):
"""check whether x is function."""
return isinstance(x, mstype.function_type)
def isinstance_(x, base_type):
"""Determine whether x is an instance of base_type."""
"""Determine whether x is an instance of base."""
x_type = F.typeof(x)
return check_type_same(x_type, base_type)
cmp_type = base_type
if check_is_tuple(F.typeof(base_type)):
cmp_type = ()
for i in base_type:
if check_is_func(F.typeof(i)) and i.__is_csr_func__():
cmp_type += (mstype.csr_tensor_type,)
else:
cmp_type += (i,)
if check_is_func(F.typeof(base_type)) and base_type.__is_csr_func__():
cmp_type = mstype.csr_tensor_type
return check_type_same(x_type, cmp_type)
def while_cond(x):
@ -1571,6 +1593,7 @@ def check_type_same(x_type, base_type):
Parameter: mstype.ref_type,
slice: mstype.Slice,
}
sparse_mstype_set = (mstype.csr_tensor_type,)
has_int = False
has_tensor = False
@ -1578,7 +1601,12 @@ def check_type_same(x_type, base_type):
def to_target_type(origin_type):
try:
if isinstance(origin_type, type):
ret_type = pytype_to_mstype[origin_type]
ret_type = None
if origin_type in pytype_to_mstype:
ret_type = pytype_to_mstype[origin_type]
elif origin_type in sparse_mstype_set:
ret_type = origin_type
if ret_type == mstype.Int:
nonlocal has_int
has_int = True

View File

@ -115,6 +115,7 @@ Ellipsis_ = typing.TypeEllipsis
none_type = typing.TypeNone
env_type_type = typing.EnvType
tensor_type = typing.TensorType
csr_tensor_type = typing.CSRTensorType
anything_type = typing.TypeAnything
ref_type = typing.RefType

View File

@ -22,6 +22,7 @@ from mindspore import Tensor, CSRTensor, ms_function, nn, context
from mindspore.ops.operations import _csr_ops
from mindspore.common import dtype as mstype
from mindspore.train.serialization import export, load
from mindspore.ops import functional as F
context.set_context(mode=context.GRAPH_MODE)
@ -364,3 +365,66 @@ def test_csrops_export_mindir():
assert np.allclose(out[4].values.asnumpy(), outputs_after_load[4].values.asnumpy())
assert out[3].shape == outputs_after_load[3].shape
assert out[4].shape == outputs_after_load[4].shape
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_isinstance_csr_tensor():
"""
Feature: Test isinstance.
Description: Test: isinstance(x, CSRTensor).
Expectation: Success.
"""
indptr = Tensor([0, 1, 2])
indices = Tensor([0, 1])
values = Tensor([2, 1], dtype=mstype.float32)
shape = (2, 4)
def pynative_test_csr_tensor():
x = CSRTensor(indptr, indices, values, shape)
# Test input CSRTensor
is_tensor = isinstance(x, Tensor)
is_bool = isinstance(x, bool)
is_float = isinstance(x, float)
is_tuple = isinstance(x, (Tensor, CSRTensor, int, float))
is_csr_tensor = isinstance(x, CSRTensor)
# Test input Tensor
is_tensor_2 = isinstance(indptr, CSRTensor)
is_tuple_2 = isinstance(indptr, (Tensor, CSRTensor))
return is_tensor, is_bool, is_float, is_tuple, is_csr_tensor, is_tensor_2, is_tuple_2
graph_test_csr_tensor = ms_function(pynative_test_csr_tensor)
out1 = pynative_test_csr_tensor()
out2 = graph_test_csr_tensor()
assert out1 == (False, False, False, True, True, False, True)
assert out2 == (False, False, False, True, True, False, True)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dtype_csr_tensor():
"""
Feature: Test F.dtype with CSRTensor.
Description: Test: F.dtype(x).
Expectation: Success.
"""
indptr = Tensor([0, 1, 2])
indices = Tensor([0, 1])
values = Tensor([2, 1], dtype=mstype.float32)
shape = (2, 4)
def pynative_test():
x = CSRTensor(indptr, indices, values, shape)
return F.dtype(x)
graph_test = ms_function(pynative_test)
out1 = pynative_test()
out2 = graph_test()
assert out1 in [mstype.float32]
assert out2 in [mstype.float32]