diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index 022ed9acbe7..298ffb97747 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -38,9 +38,8 @@ BuiltInTypeMap &GetMethodMap() { {"__bool__", std::string("none_bool")} // C.none_bool }}, {kObjectTypeFunction, - { - {"__bool__", std::string("func_bool")} // C.str_bool - }}, + {{"__bool__", std::string("func_bool")}, // C.str_bool + {"__is_csr_func__", prim::kPrimIsCSRFunc}}}, {kNumberTypeBool, { {"__and__", prim::kPrimBoolAnd}, // P.bool_and diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 8360b8b295e..06e8d51cdd5 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -163,6 +163,8 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/abstract/prim_statement.cc b/mindspore/core/abstract/prim_statement.cc index 7ff9ae03ca9..800ee5d0807 100644 --- a/mindspore/core/abstract/prim_statement.cc +++ b/mindspore/core/abstract/prim_statement.cc @@ -197,5 +197,21 @@ AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePt ValuePtr v = args_spec_list[0]->BuildValue(); return std::make_shared(!v->isa()); } + +AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Statement: x.__is_csr_func__() + // Inputs: x + auto func = CheckArg(primitive->name(), args_spec_list, 0); + MS_EXCEPTION_IF_NULL(func); + auto prim_func = dyn_cast(func); + MS_EXCEPTION_IF_NULL(prim_func); + PrimitivePtr prim = prim_func->prim(); + std::string name = prim->name(); + if (name == "S-Prim-MakeCSRTensor") { + return std::make_shared(1); + } + return std::make_shared(0); +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index d2cfc490163..29add67b060 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -126,6 +126,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimInDict, R{InferImplInDict, nullptr, true}}, {prim::kPrimNotInDict, R{InferImplNotInDict, nullptr, true}}, {prim::kPrimIsConsant, R{InferImplIsConstant, nullptr, true}}, + {prim::kPrimIsCSRFunc, R{InferImplIsCSRFunc, nullptr, true}}, // Maths {prim::kPrimMatMul, R{InferImplMatMul, nullptr, true}}, {prim::kPrimBatchMatMul, R{InferImplBatchMatMul, nullptr, true}}, diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 1a54ec6cb53..7ac07eb6426 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -140,6 +140,7 @@ constexpr auto kCSRTensorGetValues = "CSRTensorGetValues"; constexpr auto kCSRTensorGetIndptr = "CSRTensorGetIndptr"; constexpr auto kCSRTensorGetIndices = "CSRTensorGetIndices"; constexpr auto kCSRTensorGetDenseShape = "CSRTensorGetDenseShape"; +constexpr auto kIsCSRFunc = "IsCSRFunc"; // Sparse ops constexpr auto kSparseTensorDenseMatmul = "SparseTensorDenseMatmul"; @@ -583,6 +584,7 @@ MS_CORE_API inline const PrimitivePtr kPrimCSRTensorGetIndptr = std::make_shared MS_CORE_API inline const PrimitivePtr kPrimCSRTensorGetIndices = std::make_shared(kCSRTensorGetIndices); MS_CORE_API inline const PrimitivePtr kPrimCSRTensorGetDenseShape = std::make_shared(kCSRTensorGetDenseShape); +MS_CORE_API inline const PrimitivePtr kPrimIsCSRFunc = std::make_shared(kIsCSRFunc); // Sparse ops MS_CORE_API inline const PrimitivePtr kPrimSparseTensorDenseMatmul = diff --git a/mindspore/core/ops/dtype.cc b/mindspore/core/ops/dtype.cc index f5207fd1f20..76f705f9d54 100644 --- a/mindspore/core/ops/dtype.cc +++ b/mindspore/core/ops/dtype.cc @@ -32,10 +32,18 @@ ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vectorname(); (void)CheckAndConvertUtils::CheckInteger("dtype infer", int64_t(input_args.size()), kEqual, 1, op_name); MS_EXCEPTION_IF_NULL(input_args[0]); - const std::set valid_types = {kTensorType}; - auto type = - CheckAndConvertUtils::CheckTensorTypeValid("infer type", input_args[0]->BuildType(), valid_types, op_name); - return type; + auto type = input_args[0]->BuildType(); + MS_EXCEPTION_IF_NULL(type); + if (type->isa()) { + const std::set valid_types = {kTensorType}; + return CheckAndConvertUtils::CheckTensorTypeValid("infer type", type, valid_types, op_name); + } else if (input_args[0]->BuildType()->isa()) { + const std::set valid_types = {kCSRTensorType}; + return CheckAndConvertUtils::CheckCSRTensorTypeValid("infer type", type, valid_types, op_name); + } + MS_EXCEPTION(TypeError) << "For Primitive[" << op_name << "], the input argument[infer type]" + << "must be a Tensor or CSRTensor but got " << type->ToString() << "."; + return nullptr; } AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/utils/check_convert_utils.cc b/mindspore/core/utils/check_convert_utils.cc index dc93844665b..fe08ed53bbc 100644 --- a/mindspore/core/utils/check_convert_utils.cc +++ b/mindspore/core/utils/check_convert_utils.cc @@ -552,6 +552,20 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name, return CheckTensorSubClass(type_name, type, check_list, prim_name); } +TypePtr CheckAndConvertUtils::CheckCSRTensorTypeValid(const std::string &type_name, const TypePtr &type, + const std::set &check_list, + const std::string &prim_name) { + MS_EXCEPTION_IF_NULL(type); + if (!type->isa()) { + MS_EXCEPTION(TypeError) << "For Primitive[" << prim_name << "], the input argument[" << type_name + << "] must be a CSRTensor but got " << type->ToString() << "."; + } + auto csr_tensor_type = type->cast(); + auto element = csr_tensor_type->element(); + MS_EXCEPTION_IF_NULL(element); + return element; +} + ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_name, const ValuePtr &value, const std::string &prim_name) { if (value == nullptr) { diff --git a/mindspore/core/utils/check_convert_utils.h b/mindspore/core/utils/check_convert_utils.h index c639b313ee7..2c5b07e97bb 100644 --- a/mindspore/core/utils/check_convert_utils.h +++ b/mindspore/core/utils/check_convert_utils.h @@ -239,6 +239,8 @@ class CheckAndConvertUtils { const std::string &prim_name); static TypePtr CheckTensorTypeValid(const std::string &type_name, const TypePtr &type, const std::set &check_list, const std::string &prim_name); + static TypePtr CheckCSRTensorTypeValid(const std::string &type_name, const TypePtr &type, + const std::set &check_list, const std::string &prim_name); static TypePtr CheckSubClass(const std::string &type_name, const TypePtr &type, const std::set &template_types, const std::string &prim_name); static TypePtr CheckScalarOrTensorTypesSame(const std::map &args, diff --git a/mindspore/python/mindspore/_extends/parse/standard_method.py b/mindspore/python/mindspore/_extends/parse/standard_method.py index 38642a6b415..ed0f4d7cf9e 100644 --- a/mindspore/python/mindspore/_extends/parse/standard_method.py +++ b/mindspore/python/mindspore/_extends/parse/standard_method.py @@ -1536,10 +1536,32 @@ def view(x, *shape): return F.reshape(x, shape) +@constexpr +def check_is_tuple(x): + """check whether x is tuple.""" + return isinstance(x, mstype.Tuple) + + +@constexpr +def check_is_func(x): + """check whether x is function.""" + return isinstance(x, mstype.function_type) + + def isinstance_(x, base_type): - """Determine whether x is an instance of base_type.""" + """Determine whether x is an instance of base.""" x_type = F.typeof(x) - return check_type_same(x_type, base_type) + cmp_type = base_type + if check_is_tuple(F.typeof(base_type)): + cmp_type = () + for i in base_type: + if check_is_func(F.typeof(i)) and i.__is_csr_func__(): + cmp_type += (mstype.csr_tensor_type,) + else: + cmp_type += (i,) + if check_is_func(F.typeof(base_type)) and base_type.__is_csr_func__(): + cmp_type = mstype.csr_tensor_type + return check_type_same(x_type, cmp_type) def while_cond(x): @@ -1571,6 +1593,7 @@ def check_type_same(x_type, base_type): Parameter: mstype.ref_type, slice: mstype.Slice, } + sparse_mstype_set = (mstype.csr_tensor_type,) has_int = False has_tensor = False @@ -1578,7 +1601,12 @@ def check_type_same(x_type, base_type): def to_target_type(origin_type): try: if isinstance(origin_type, type): - ret_type = pytype_to_mstype[origin_type] + ret_type = None + if origin_type in pytype_to_mstype: + ret_type = pytype_to_mstype[origin_type] + elif origin_type in sparse_mstype_set: + ret_type = origin_type + if ret_type == mstype.Int: nonlocal has_int has_int = True diff --git a/mindspore/python/mindspore/common/dtype.py b/mindspore/python/mindspore/common/dtype.py index 971db62a541..3632cfe7d8d 100644 --- a/mindspore/python/mindspore/common/dtype.py +++ b/mindspore/python/mindspore/common/dtype.py @@ -115,6 +115,7 @@ Ellipsis_ = typing.TypeEllipsis none_type = typing.TypeNone env_type_type = typing.EnvType tensor_type = typing.TensorType +csr_tensor_type = typing.CSRTensorType anything_type = typing.TypeAnything ref_type = typing.RefType diff --git a/tests/st/sparse/test_csr.py b/tests/st/sparse/test_csr.py index f62fc4536a7..bca9a382547 100644 --- a/tests/st/sparse/test_csr.py +++ b/tests/st/sparse/test_csr.py @@ -22,6 +22,7 @@ from mindspore import Tensor, CSRTensor, ms_function, nn, context from mindspore.ops.operations import _csr_ops from mindspore.common import dtype as mstype from mindspore.train.serialization import export, load +from mindspore.ops import functional as F context.set_context(mode=context.GRAPH_MODE) @@ -364,3 +365,66 @@ def test_csrops_export_mindir(): assert np.allclose(out[4].values.asnumpy(), outputs_after_load[4].values.asnumpy()) assert out[3].shape == outputs_after_load[3].shape assert out[4].shape == outputs_after_load[4].shape + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_isinstance_csr_tensor(): + """ + Feature: Test isinstance. + Description: Test: isinstance(x, CSRTensor). + Expectation: Success. + """ + indptr = Tensor([0, 1, 2]) + indices = Tensor([0, 1]) + values = Tensor([2, 1], dtype=mstype.float32) + shape = (2, 4) + + def pynative_test_csr_tensor(): + x = CSRTensor(indptr, indices, values, shape) + # Test input CSRTensor + is_tensor = isinstance(x, Tensor) + is_bool = isinstance(x, bool) + is_float = isinstance(x, float) + is_tuple = isinstance(x, (Tensor, CSRTensor, int, float)) + is_csr_tensor = isinstance(x, CSRTensor) + + # Test input Tensor + is_tensor_2 = isinstance(indptr, CSRTensor) + is_tuple_2 = isinstance(indptr, (Tensor, CSRTensor)) + return is_tensor, is_bool, is_float, is_tuple, is_csr_tensor, is_tensor_2, is_tuple_2 + graph_test_csr_tensor = ms_function(pynative_test_csr_tensor) + + out1 = pynative_test_csr_tensor() + out2 = graph_test_csr_tensor() + assert out1 == (False, False, False, True, True, False, True) + assert out2 == (False, False, False, True, True, False, True) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_dtype_csr_tensor(): + """ + Feature: Test F.dtype with CSRTensor. + Description: Test: F.dtype(x). + Expectation: Success. + """ + indptr = Tensor([0, 1, 2]) + indices = Tensor([0, 1]) + values = Tensor([2, 1], dtype=mstype.float32) + shape = (2, 4) + + def pynative_test(): + x = CSRTensor(indptr, indices, values, shape) + return F.dtype(x) + graph_test = ms_function(pynative_test) + + out1 = pynative_test() + out2 = graph_test() + assert out1 in [mstype.float32] + assert out2 in [mstype.float32]