forked from mindspore-Ecosystem/mindspore
support isinstance(x, CSRTensor) in graph mode
support F.dtype(csr_tensor)
This commit is contained in:
parent
cedca0b0c0
commit
b24f1f833a
|
@ -38,9 +38,8 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"__bool__", std::string("none_bool")} // C.none_bool
|
||||
}},
|
||||
{kObjectTypeFunction,
|
||||
{
|
||||
{"__bool__", std::string("func_bool")} // C.str_bool
|
||||
}},
|
||||
{{"__bool__", std::string("func_bool")}, // C.str_bool
|
||||
{"__is_csr_func__", prim::kPrimIsCSRFunc}}},
|
||||
{kNumberTypeBool,
|
||||
{
|
||||
{"__and__", prim::kPrimBoolAnd}, // P.bool_and
|
||||
|
|
|
@ -163,6 +163,8 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -197,5 +197,21 @@ AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePt
|
|||
ValuePtr v = args_spec_list[0]->BuildValue();
|
||||
return std::make_shared<AbstractScalar>(!v->isa<AnyValue>());
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Statement: x.__is_csr_func__()
|
||||
// Inputs: x
|
||||
auto func = CheckArg<AbstractFunction>(primitive->name(), args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
|
||||
MS_EXCEPTION_IF_NULL(prim_func);
|
||||
PrimitivePtr prim = prim_func->prim();
|
||||
std::string name = prim->name();
|
||||
if (name == "S-Prim-MakeCSRTensor") {
|
||||
return std::make_shared<AbstractScalar>(1);
|
||||
}
|
||||
return std::make_shared<AbstractScalar>(0);
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -126,6 +126,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimInDict, R{InferImplInDict, nullptr, true}},
|
||||
{prim::kPrimNotInDict, R{InferImplNotInDict, nullptr, true}},
|
||||
{prim::kPrimIsConsant, R{InferImplIsConstant, nullptr, true}},
|
||||
{prim::kPrimIsCSRFunc, R{InferImplIsCSRFunc, nullptr, true}},
|
||||
// Maths
|
||||
{prim::kPrimMatMul, R{InferImplMatMul, nullptr, true}},
|
||||
{prim::kPrimBatchMatMul, R{InferImplBatchMatMul, nullptr, true}},
|
||||
|
|
|
@ -140,6 +140,7 @@ constexpr auto kCSRTensorGetValues = "CSRTensorGetValues";
|
|||
constexpr auto kCSRTensorGetIndptr = "CSRTensorGetIndptr";
|
||||
constexpr auto kCSRTensorGetIndices = "CSRTensorGetIndices";
|
||||
constexpr auto kCSRTensorGetDenseShape = "CSRTensorGetDenseShape";
|
||||
constexpr auto kIsCSRFunc = "IsCSRFunc";
|
||||
|
||||
// Sparse ops
|
||||
constexpr auto kSparseTensorDenseMatmul = "SparseTensorDenseMatmul";
|
||||
|
@ -583,6 +584,7 @@ MS_CORE_API inline const PrimitivePtr kPrimCSRTensorGetIndptr = std::make_shared
|
|||
MS_CORE_API inline const PrimitivePtr kPrimCSRTensorGetIndices = std::make_shared<Primitive>(kCSRTensorGetIndices);
|
||||
MS_CORE_API inline const PrimitivePtr kPrimCSRTensorGetDenseShape =
|
||||
std::make_shared<Primitive>(kCSRTensorGetDenseShape);
|
||||
MS_CORE_API inline const PrimitivePtr kPrimIsCSRFunc = std::make_shared<Primitive>(kIsCSRFunc);
|
||||
|
||||
// Sparse ops
|
||||
MS_CORE_API inline const PrimitivePtr kPrimSparseTensorDenseMatmul =
|
||||
|
|
|
@ -32,10 +32,18 @@ ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<Abstra
|
|||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("dtype infer", int64_t(input_args.size()), kEqual, 1, op_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
const std::set<TypePtr> valid_types = {kTensorType};
|
||||
auto type =
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("infer type", input_args[0]->BuildType(), valid_types, op_name);
|
||||
return type;
|
||||
auto type = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (type->isa<TensorType>()) {
|
||||
const std::set<TypePtr> valid_types = {kTensorType};
|
||||
return CheckAndConvertUtils::CheckTensorTypeValid("infer type", type, valid_types, op_name);
|
||||
} else if (input_args[0]->BuildType()->isa<CSRTensorType>()) {
|
||||
const std::set<TypePtr> valid_types = {kCSRTensorType};
|
||||
return CheckAndConvertUtils::CheckCSRTensorTypeValid("infer type", type, valid_types, op_name);
|
||||
}
|
||||
MS_EXCEPTION(TypeError) << "For Primitive[" << op_name << "], the input argument[infer type]"
|
||||
<< "must be a Tensor or CSRTensor but got " << type->ToString() << ".";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -552,6 +552,20 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeValid(const std::string &type_name,
|
|||
return CheckTensorSubClass(type_name, type, check_list, prim_name);
|
||||
}
|
||||
|
||||
TypePtr CheckAndConvertUtils::CheckCSRTensorTypeValid(const std::string &type_name, const TypePtr &type,
|
||||
const std::set<TypePtr> &check_list,
|
||||
const std::string &prim_name) {
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (!type->isa<CSRTensorType>()) {
|
||||
MS_EXCEPTION(TypeError) << "For Primitive[" << prim_name << "], the input argument[" << type_name
|
||||
<< "] must be a CSRTensor but got " << type->ToString() << ".";
|
||||
}
|
||||
auto csr_tensor_type = type->cast<CSRTensorTypePtr>();
|
||||
auto element = csr_tensor_type->element();
|
||||
MS_EXCEPTION_IF_NULL(element);
|
||||
return element;
|
||||
}
|
||||
|
||||
ShapeVector CheckAndConvertUtils::CheckTensorIntValue(const std::string &type_name, const ValuePtr &value,
|
||||
const std::string &prim_name) {
|
||||
if (value == nullptr) {
|
||||
|
|
|
@ -239,6 +239,8 @@ class CheckAndConvertUtils {
|
|||
const std::string &prim_name);
|
||||
static TypePtr CheckTensorTypeValid(const std::string &type_name, const TypePtr &type,
|
||||
const std::set<TypePtr> &check_list, const std::string &prim_name);
|
||||
static TypePtr CheckCSRTensorTypeValid(const std::string &type_name, const TypePtr &type,
|
||||
const std::set<TypePtr> &check_list, const std::string &prim_name);
|
||||
static TypePtr CheckSubClass(const std::string &type_name, const TypePtr &type,
|
||||
const std::set<TypePtr> &template_types, const std::string &prim_name);
|
||||
static TypePtr CheckScalarOrTensorTypesSame(const std::map<std::string, TypePtr> &args,
|
||||
|
|
|
@ -1536,10 +1536,32 @@ def view(x, *shape):
|
|||
return F.reshape(x, shape)
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_is_tuple(x):
|
||||
"""check whether x is tuple."""
|
||||
return isinstance(x, mstype.Tuple)
|
||||
|
||||
|
||||
@constexpr
|
||||
def check_is_func(x):
|
||||
"""check whether x is function."""
|
||||
return isinstance(x, mstype.function_type)
|
||||
|
||||
|
||||
def isinstance_(x, base_type):
|
||||
"""Determine whether x is an instance of base_type."""
|
||||
"""Determine whether x is an instance of base."""
|
||||
x_type = F.typeof(x)
|
||||
return check_type_same(x_type, base_type)
|
||||
cmp_type = base_type
|
||||
if check_is_tuple(F.typeof(base_type)):
|
||||
cmp_type = ()
|
||||
for i in base_type:
|
||||
if check_is_func(F.typeof(i)) and i.__is_csr_func__():
|
||||
cmp_type += (mstype.csr_tensor_type,)
|
||||
else:
|
||||
cmp_type += (i,)
|
||||
if check_is_func(F.typeof(base_type)) and base_type.__is_csr_func__():
|
||||
cmp_type = mstype.csr_tensor_type
|
||||
return check_type_same(x_type, cmp_type)
|
||||
|
||||
|
||||
def while_cond(x):
|
||||
|
@ -1571,6 +1593,7 @@ def check_type_same(x_type, base_type):
|
|||
Parameter: mstype.ref_type,
|
||||
slice: mstype.Slice,
|
||||
}
|
||||
sparse_mstype_set = (mstype.csr_tensor_type,)
|
||||
|
||||
has_int = False
|
||||
has_tensor = False
|
||||
|
@ -1578,7 +1601,12 @@ def check_type_same(x_type, base_type):
|
|||
def to_target_type(origin_type):
|
||||
try:
|
||||
if isinstance(origin_type, type):
|
||||
ret_type = pytype_to_mstype[origin_type]
|
||||
ret_type = None
|
||||
if origin_type in pytype_to_mstype:
|
||||
ret_type = pytype_to_mstype[origin_type]
|
||||
elif origin_type in sparse_mstype_set:
|
||||
ret_type = origin_type
|
||||
|
||||
if ret_type == mstype.Int:
|
||||
nonlocal has_int
|
||||
has_int = True
|
||||
|
|
|
@ -115,6 +115,7 @@ Ellipsis_ = typing.TypeEllipsis
|
|||
none_type = typing.TypeNone
|
||||
env_type_type = typing.EnvType
|
||||
tensor_type = typing.TensorType
|
||||
csr_tensor_type = typing.CSRTensorType
|
||||
anything_type = typing.TypeAnything
|
||||
ref_type = typing.RefType
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ from mindspore import Tensor, CSRTensor, ms_function, nn, context
|
|||
from mindspore.ops.operations import _csr_ops
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.serialization import export, load
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
@ -364,3 +365,66 @@ def test_csrops_export_mindir():
|
|||
assert np.allclose(out[4].values.asnumpy(), outputs_after_load[4].values.asnumpy())
|
||||
assert out[3].shape == outputs_after_load[3].shape
|
||||
assert out[4].shape == outputs_after_load[4].shape
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_isinstance_csr_tensor():
|
||||
"""
|
||||
Feature: Test isinstance.
|
||||
Description: Test: isinstance(x, CSRTensor).
|
||||
Expectation: Success.
|
||||
"""
|
||||
indptr = Tensor([0, 1, 2])
|
||||
indices = Tensor([0, 1])
|
||||
values = Tensor([2, 1], dtype=mstype.float32)
|
||||
shape = (2, 4)
|
||||
|
||||
def pynative_test_csr_tensor():
|
||||
x = CSRTensor(indptr, indices, values, shape)
|
||||
# Test input CSRTensor
|
||||
is_tensor = isinstance(x, Tensor)
|
||||
is_bool = isinstance(x, bool)
|
||||
is_float = isinstance(x, float)
|
||||
is_tuple = isinstance(x, (Tensor, CSRTensor, int, float))
|
||||
is_csr_tensor = isinstance(x, CSRTensor)
|
||||
|
||||
# Test input Tensor
|
||||
is_tensor_2 = isinstance(indptr, CSRTensor)
|
||||
is_tuple_2 = isinstance(indptr, (Tensor, CSRTensor))
|
||||
return is_tensor, is_bool, is_float, is_tuple, is_csr_tensor, is_tensor_2, is_tuple_2
|
||||
graph_test_csr_tensor = ms_function(pynative_test_csr_tensor)
|
||||
|
||||
out1 = pynative_test_csr_tensor()
|
||||
out2 = graph_test_csr_tensor()
|
||||
assert out1 == (False, False, False, True, True, False, True)
|
||||
assert out2 == (False, False, False, True, True, False, True)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dtype_csr_tensor():
|
||||
"""
|
||||
Feature: Test F.dtype with CSRTensor.
|
||||
Description: Test: F.dtype(x).
|
||||
Expectation: Success.
|
||||
"""
|
||||
indptr = Tensor([0, 1, 2])
|
||||
indices = Tensor([0, 1])
|
||||
values = Tensor([2, 1], dtype=mstype.float32)
|
||||
shape = (2, 4)
|
||||
|
||||
def pynative_test():
|
||||
x = CSRTensor(indptr, indices, values, shape)
|
||||
return F.dtype(x)
|
||||
graph_test = ms_function(pynative_test)
|
||||
|
||||
out1 = pynative_test()
|
||||
out2 = graph_test()
|
||||
assert out1 in [mstype.float32]
|
||||
assert out2 in [mstype.float32]
|
||||
|
|
Loading…
Reference in New Issue