From e4c1900fe8cbf1ea279d859609d8505b0c096866 Mon Sep 17 00:00:00 2001
From: yanglf1121
Date: Tue, 2 Aug 2022 20:03:35 +0800
Subject: [PATCH] move sparse infer functions

---
 mindspore/core/abstract/ops/infer_functions.h | 36 --
 mindspore/core/abstract/ops/prim_others.cc | 526 ------
 .../core/abstract/ops/primitive_infer_map.cc | 19 -
 mindspore/core/ops/coo_tensor_get_indices.cc | 41 ++
 mindspore/core/ops/coo_tensor_get_indices.h | 39 ++
 mindspore/core/ops/coo_tensor_get_shape.cc | 41 ++
 mindspore/core/ops/coo_tensor_get_shape.h | 39 ++
 mindspore/core/ops/coo_tensor_get_values.cc | 40 ++
 mindspore/core/ops/coo_tensor_get_values.h | 39 ++
 mindspore/core/ops/coo_to_csr.cc | 59 ++
 mindspore/core/ops/coo_to_csr.h | 43 ++
 mindspore/core/ops/csr_elementwise.cc | 68 +++
 mindspore/core/ops/csr_elementwise.h | 56 ++
 mindspore/core/ops/csr_gather.cc | 70 +++
 mindspore/core/ops/csr_gather.h | 43 ++
 mindspore/core/ops/csr_mm.cc | 78 +++
 mindspore/core/ops/csr_mm.h | 45 ++
 mindspore/core/ops/csr_mv.cc | 77 +++
 mindspore/core/ops/csr_mv.h | 45 ++
 mindspore/core/ops/csr_reducesum.cc | 85 +++
 mindspore/core/ops/csr_reducesum.h | 45 ++
 mindspore/core/ops/csr_tensor_get_indices.cc | 41 ++
 mindspore/core/ops/csr_tensor_get_indices.h | 40 ++
 mindspore/core/ops/csr_tensor_get_indptr.cc | 40 ++
 mindspore/core/ops/csr_tensor_get_indptr.h | 40 ++
 mindspore/core/ops/csr_tensor_get_shape.cc | 41 ++
 mindspore/core/ops/csr_tensor_get_shape.h | 39 ++
 mindspore/core/ops/csr_tensor_get_values.cc | 40 ++
 mindspore/core/ops/csr_tensor_get_values.h | 40 ++
 mindspore/core/ops/csr_to_coo.cc | 67 +++
 mindspore/core/ops/csr_to_coo.h | 43 ++
 mindspore/core/ops/make_cootensor.cc | 104 ++++
 mindspore/core/ops/make_cootensor.h | 39 ++
 mindspore/core/ops/make_csrtensor.cc | 104 ++++
 mindspore/core/ops/make_csrtensor.h | 40 ++
 mindspore/core/ops/op_utils.cc | 94 ++++
 mindspore/core/ops/op_utils.h | 18 +
 mindspore/core/ops/sparse_matrix_add.cc | 15 -
 38 files changed, 1783 insertions(+), 596 deletions(-)
 create mode 100644 mindspore/core/ops/coo_tensor_get_indices.cc
 create mode 100644 mindspore/core/ops/coo_tensor_get_indices.h
 create mode 100644 mindspore/core/ops/coo_tensor_get_shape.cc
 create mode 100644 mindspore/core/ops/coo_tensor_get_shape.h
 create mode 100644 mindspore/core/ops/coo_tensor_get_values.cc
 create mode 100644 mindspore/core/ops/coo_tensor_get_values.h
 create mode 100644 mindspore/core/ops/coo_to_csr.cc
 create mode 100644 mindspore/core/ops/coo_to_csr.h
 create mode 100644 mindspore/core/ops/csr_elementwise.cc
 create mode 100644 mindspore/core/ops/csr_elementwise.h
 create mode 100644 mindspore/core/ops/csr_gather.cc
 create mode 100644 mindspore/core/ops/csr_gather.h
 create mode 100644 mindspore/core/ops/csr_mm.cc
 create mode 100644 mindspore/core/ops/csr_mm.h
 create mode 100644 mindspore/core/ops/csr_mv.cc
 create mode 100644 mindspore/core/ops/csr_mv.h
 create mode 100644 mindspore/core/ops/csr_reducesum.cc
 create mode 100644 mindspore/core/ops/csr_reducesum.h
 create mode 100644 mindspore/core/ops/csr_tensor_get_indices.cc
 create mode 100644 mindspore/core/ops/csr_tensor_get_indices.h
 create mode 100644 mindspore/core/ops/csr_tensor_get_indptr.cc
 create mode 100644 mindspore/core/ops/csr_tensor_get_indptr.h
 create mode 100644 mindspore/core/ops/csr_tensor_get_shape.cc
 create mode 100644 mindspore/core/ops/csr_tensor_get_shape.h
 create mode 100644 mindspore/core/ops/csr_tensor_get_values.cc
 create mode 100644 mindspore/core/ops/csr_tensor_get_values.h
 create mode 100644 mindspore/core/ops/csr_to_coo.cc
create mode 100644 mindspore/core/ops/csr_to_coo.h create mode 100644 mindspore/core/ops/make_cootensor.cc create mode 100644 mindspore/core/ops/make_cootensor.h create mode 100644 mindspore/core/ops/make_csrtensor.cc create mode 100644 mindspore/core/ops/make_csrtensor.h diff --git a/mindspore/core/abstract/ops/infer_functions.h b/mindspore/core/abstract/ops/infer_functions.h index 7911a494050..dda367fc35b 100644 --- a/mindspore/core/abstract/ops/infer_functions.h +++ b/mindspore/core/abstract/ops/infer_functions.h @@ -129,30 +129,6 @@ AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitiveP const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -template -std::shared_ptr InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeCOOTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplCOOTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplCOOTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplCOOTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplCSRElementWise(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplCSRMM(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, @@ -163,18 +139,6 @@ AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplCSRTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplCSRTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplCSRTensorGetIndptr(const AnalysisEnginePtr &, const 
PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplCSRTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/abstract/ops/prim_others.cc b/mindspore/core/abstract/ops/prim_others.cc index 092b524b83c..2af831116af 100644 --- a/mindspore/core/abstract/ops/prim_others.cc +++ b/mindspore/core/abstract/ops/prim_others.cc @@ -34,66 +34,10 @@ namespace { constexpr auto kRankSize = "rank_size"; -inline void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp) { - constexpr auto kCSRMulBatchPos = 2; - int dlen = mindspore::SizeToInt(sparse_shp.size()) - mindspore::SizeToInt(dense_shp.size()); - if (dlen < 0) { - MS_EXCEPTION(mindspore::ValueError) << "Currently, only support dense tensor broadcast to sparse tensor, " - << "but sparse tensor has " << sparse_shp.size() << " dimensions, " - << "and dense tensor has " << dense_shp.size() << " dimensions, "; - } - for (int i = 0; i < dlen; i++) { - (void)dense_shp.insert(dense_shp.begin(), 1); - } - if (sparse_shp.size() != dense_shp.size()) { - MS_LOG(EXCEPTION) << "Failure: sparse_shp.size() != dense_shp.size()."; - } - if (sparse_shp.size() < 1) { - MS_LOG(EXCEPTION) << "Failure: dense tensor and sparse tensor shapes cannot be zero."; - } - for (size_t i = 0; i < sparse_shp.size(); i++) { - auto s = sparse_shp[i]; - auto d = dense_shp[i]; - if (i < kCSRMulBatchPos) { - if (d != s && d != 1) { - MS_EXCEPTION(mindspore::ValueError) << "Dense shape cannot broadcast to sparse shape."; - } - } else { - if (d != s) { - MS_EXCEPTION(mindspore::ValueError) - << "Currently, sparse shape and dense shape must equal in feature dimensions."; - } - } - } -} -inline void CheckSparseShape(const size_t shape_size, const size_t expected_dim, const std::string &arg_name) { - if (shape_size != expected_dim) { - MS_EXCEPTION(mindspore::ValueError) << arg_name << " must be a " << expected_dim - << "-dimensional tensor, but got a " << shape_size << "-dimensional tensor."; - } -} -inline void CheckSparseIndicesDtype(const mindspore::TypePtr data_type, const std::string &arg_name) { - if (!(data_type->equal(mindspore::kInt16) || data_type->equal(mindspore::kInt32) || - data_type->equal(mindspore::kInt64))) { - MS_EXCEPTION(mindspore::TypeError) << "The dtype of " << arg_name << " must be Int16 or Int32 or Int64, but got " - << data_type->ToString() << "."; - } -} -inline void CheckSparseIndicesDtypeInt32(const mindspore::TypePtr data_type, const std::string &arg_name) { - if (!data_type->equal(mindspore::kInt32)) { - MS_EXCEPTION(mindspore::TypeError) << "The dtype of " << arg_name << " only support Int32 for now, but got " - << data_type->ToString() << "."; - } -} } // namespace namespace mindspore { namespace abstract { -constexpr auto kCSRDenseShape = "dense_shape"; -constexpr auto kCSRAxis = "axis"; -constexpr auto kCSRAvgRows = "csr_avg_rows"; -constexpr auto kIsCSR = "is_csr"; - AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // An object of a subclass of AbstractBase @@ -343,433 +287,6 @@ AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const Primitive return args_spec_list[0]; } 
-AbstractBasePtr InferImplMakeCOOTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tensors and a tuple. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, kSizeThree); - auto indices = CheckArg(op_name, args_spec_list, kIndexZero); - auto values = CheckArg(op_name, args_spec_list, kIndexOne); - auto dense_shape = CheckArg(op_name, args_spec_list, kIndexTwo); - - auto indices_dtype = indices->element()->BuildType(); - CheckSparseIndicesDtype(indices_dtype, "Indices"); - - auto indices_shp = indices->shape()->shape(); - CheckSparseShape(indices_shp.size(), kSizeTwo, "Indices"); - - auto values_shp = values->shape()->shape(); - CheckSparseShape(values_shp.size(), kSizeOne, "Values"); - - if (indices_shp[kIndexZero] != values_shp[kIndexZero]) { - MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexZero << "]` must be equal to `values.shape[" - << kIndexZero << "]`, but got `indices.shape[" << kIndexZero - << "]`: " << indices_shp[kIndexZero] << " and `values.shape[" << kIndexZero - << "]`: " << values_shp[kIndexZero]; - } - constexpr int64_t kDimTwo = 2; - if (indices_shp[kIndexOne] != kDimTwo) { - MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexOne << "]` must be " << kDimTwo << ",but got " - << indices_shp[kIndexOne]; - } - - for (const auto &elem_type : dense_shape->ElementsType()) { - if (!elem_type->isa()) { - MS_EXCEPTION(TypeError) << "For COOTensor, the element type of `shape` must be Int, but got " - << elem_type->ToString(); - } - } - auto dense_shape_value = dense_shape->BuildValue()->cast(); - MS_EXCEPTION_IF_NULL(dense_shape_value); - auto shp = dense_shape_value->value(); - auto min_elem = *std::min_element(std::begin(shp), std::end(shp)); - if (min_elem <= 0) { - MS_EXCEPTION(ValueError) << "For COOTensor, the element of `shape` must be positive integer. But got " << min_elem - << "int it"; - } - ShapeVector dense_shape_vec; - (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec), - [](const ValuePtr &e) -> int64_t { - auto elem = GetValue(e); - return elem; - }); - if (LongToSize(indices_shp[kIndexOne]) != dense_shape_vec.size()) { - MS_EXCEPTION(TypeError) << "For COOTensor, `indices.shape[" << indices_shp << "]` must be equal to the second " - << "dimension of `indices`: " << dense_shape_vec.size() << " but got " - << indices_shp[kIndexOne]; - } - for (auto dense_shape_elem : dense_shape_vec) { - if (dense_shape_elem <= 0) { - MS_EXCEPTION(TypeError) << "For COOTensor, the element of `shape` must be positive, but got " - << dense_shape_value->ToString(); - } - } - AbstractBasePtrList element_list{indices, values, dense_shape}; - return std::make_shared(element_list); -} - -AbstractBasePtr InferImplCOOTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tensors and a tuple. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - auto sparse_tensor = CheckArg(op_name, args_spec_list, 0); - MS_EXCEPTION_IF_NULL(sparse_tensor->values()); - return sparse_tensor->values(); -} - -AbstractBasePtr InferImplCOOTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tensors and a tuple. 
- const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - auto sparse_tensor = CheckArg(op_name, args_spec_list, 0); - MS_EXCEPTION_IF_NULL(sparse_tensor->indices()); - return sparse_tensor->indices(); -} - -AbstractBasePtr InferImplCOOTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tensors and a tuple. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - auto sparse_tensor = CheckArg(op_name, args_spec_list, 0); - MS_EXCEPTION_IF_NULL(sparse_tensor->shape()); - return sparse_tensor->shape(); -} - -ShapeVector ConvertToShapeVector(const AbstractTuplePtr &shape) { - auto shape_value = shape->BuildValue()->cast(); - MS_EXCEPTION_IF_NULL(shape_value); - ShapeVector shape_vec; - (void)std::transform(std::begin(shape_value->value()), std::end(shape_value->value()), std::back_inserter(shape_vec), - [](const ValuePtr &e) -> int64_t { - auto elem = GetValue(e); - return elem; - }); - return shape_vec; -} - -AbstractBasePtr InferImplCSRElementWise(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a sparse tensor and a dense tensor. - constexpr auto kCSRElementwiseInputsNum = 5; - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, kCSRElementwiseInputsNum); - auto indptr = CheckArg(op_name, args_spec_list, 0); - auto indices = CheckArg(op_name, args_spec_list, 1); - auto values = CheckArg(op_name, args_spec_list, 2); - auto shape = CheckArg(op_name, args_spec_list, 3); - auto dense = CheckArg(op_name, args_spec_list, 4); - MS_EXCEPTION_IF_NULL(indptr); - MS_EXCEPTION_IF_NULL(indices); - MS_EXCEPTION_IF_NULL(values); - MS_EXCEPTION_IF_NULL(shape); - MS_EXCEPTION_IF_NULL(dense); - - CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr"); - CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices"); - - ShapeVector sparse_shape = ConvertToShapeVector(shape); - auto dense_shape = dense->shape()->shape(); - CheckSparseShape(sparse_shape, dense_shape); - auto ret = values->Broaden(); - // SetAttr - auto nnz_vec = indices->shape()->shape(); - auto csr_avg_rows = nnz_vec[0] / dense_shape[0]; - primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); - primitive->set_attr(kIsCSR, MakeValue(true)); - return ret; -} - -AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - constexpr auto kCSRMVInputsNum = 5; - constexpr auto kCSRMVShapeSize = 2; - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, kCSRMVInputsNum); - auto indptr = CheckArg(op_name, args_spec_list, 0); - auto indices = CheckArg(op_name, args_spec_list, 1); - auto values = CheckArg(op_name, args_spec_list, 2); - auto shape = CheckArg(op_name, args_spec_list, 3); - auto dense = CheckArg(op_name, args_spec_list, 4); - MS_EXCEPTION_IF_NULL(indptr); - MS_EXCEPTION_IF_NULL(indices); - MS_EXCEPTION_IF_NULL(values); - MS_EXCEPTION_IF_NULL(shape); - MS_EXCEPTION_IF_NULL(dense); - - CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr"); - CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices"); - - ShapeVector sparse_shape = ConvertToShapeVector(shape); - ShapeVector dense_shape = dense->shape()->shape(); - if (sparse_shape.size() != kCSRMVShapeSize || dense_shape.size() != kCSRMVShapeSize) { - 
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMVShapeSize << "-D inputs! " - << "But csr tensor has " << sparse_shape.size() << " dimensions, " - << "and dense tensor has " << dense_shape.size() << " dimension(s). "; - } - if (dense_shape[kIndexZero] != sparse_shape[kIndexOne] || dense_shape[kIndexOne] != 1) { - MS_EXCEPTION(ValueError) << "The dense_vector's shape should be (" << sparse_shape[kIndexOne] << ", 1)" - << ", but its current shape is: " - << "(" << dense_shape[kIndexZero] << ", " << dense_shape[kIndexOne] << ")."; - } - - ShapeVector out_shape = {sparse_shape[kIndexZero], dense_shape[kIndexOne]}; - auto ret = std::make_shared(values->element()->BuildType(), out_shape); - // SetAttr - auto nnz_vec = indices->shape()->shape(); - auto csr_avg_rows = nnz_vec[kIndexZero] / dense_shape[kIndexZero]; - primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); - primitive->set_attr(kIsCSR, MakeValue(true)); - return ret; -} - -AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a sparse tensor and an axis. - constexpr auto kCSRReduceSumInputsNum = 5; - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, kCSRReduceSumInputsNum); - auto indptr = CheckArg(op_name, args_spec_list, 0); - auto indices = CheckArg(op_name, args_spec_list, 1); - auto values = CheckArg(op_name, args_spec_list, 2); - auto shape = CheckArg(op_name, args_spec_list, 3); - auto axis = CheckArg(op_name, args_spec_list, 4); - MS_EXCEPTION_IF_NULL(indptr); - MS_EXCEPTION_IF_NULL(indices); - MS_EXCEPTION_IF_NULL(values); - MS_EXCEPTION_IF_NULL(shape); - MS_EXCEPTION_IF_NULL(axis); - - CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr"); - CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices"); - - ShapeVector sparse_shape = ConvertToShapeVector(shape); - ShapeVector out_shape = sparse_shape; - MS_EXCEPTION_IF_NULL(axis->BuildValue()); - if (axis->BuildValue()->isa() || axis->BuildValue()->isa()) { - int64_t axis_value = GetValue(axis->BuildValue()); - int64_t dim = static_cast(sparse_shape.size()); - if (axis_value != 1 && axis_value != 1 - dim) { - MS_EXCEPTION(ValueError) << "For CSRReduceSum, `axis` should be 1 or 1-dim. But got `axis`: " << axis_value - << "and `1- dim`: " << 1 - dim; - } - if (axis_value < 0) { - axis_value += dim; - } - out_shape[LongToSize(axis_value)] = 1; - primitive->set_attr(kCSRAxis, MakeValue(axis_value)); - } else { - MS_EXCEPTION(TypeError) << "For CSRReduceSum, `axis` should be int32 or int64, but got " - << axis->BuildType()->ToString(); - } - - MS_EXCEPTION_IF_NULL(values->element()); - auto ret = std::make_shared(values->element()->BuildType(), out_shape); - // SetAttr - auto nnz_vec = indices->shape()->shape(); - auto csr_avg_rows = nnz_vec[0] / sparse_shape[0]; - primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); - primitive->set_attr(kIsCSR, MakeValue(true)); - return ret; -} - -AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: the indptr and indices of a sparse csr tensor, a dense tensor, and the shape of the sparse tensor. 
- constexpr size_t csr_row_num = 2; - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, kSizeFour); - auto indptr = CheckArg(op_name, args_spec_list, kIndexZero); - auto indices = CheckArg(op_name, args_spec_list, kIndexOne); - auto dense = CheckArg(op_name, args_spec_list, kIndexTwo); - auto sparse_shape = CheckArg(op_name, args_spec_list, kIndexThree); - MS_EXCEPTION_IF_NULL(indptr); - MS_EXCEPTION_IF_NULL(indices); - MS_EXCEPTION_IF_NULL(dense); - MS_EXCEPTION_IF_NULL(sparse_shape); - - CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr"); - CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices"); - - auto shape_value = sparse_shape->BuildValue()->cast(); - MS_EXCEPTION_IF_NULL(shape_value); - auto nnz_vec = indices->shape()->shape(); - int64_t csr_avg_rows = nnz_vec[0] / GetValue(shape_value->value()[0]); - primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); - primitive->set_attr(kIsCSR, MakeValue(true)); - - MS_EXCEPTION_IF_NULL(indices->shape()); - ShapeVector out_shape = indices->shape()->shape(); - MS_EXCEPTION_IF_NULL(dense->shape()); - ShapeVector dense_shape = dense->shape()->shape(); - for (size_t i = csr_row_num; i < dense_shape.size(); ++i) { - out_shape.push_back(dense_shape[i]); - } - MS_EXCEPTION_IF_NULL(dense->element()); - auto ret = std::make_shared(dense->element()->BuildType(), out_shape); - return ret; -} - -AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: the indptr of a sparse csr tensor, and the number of non-zero elements. - constexpr auto kCSRArgsSize = 2; - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, kCSRArgsSize); - auto indptr = CheckArg(op_name, args_spec_list, 0); - CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr"); - - auto nnz = CheckArg(op_name, args_spec_list, 1); - MS_EXCEPTION_IF_NULL(indptr); - MS_EXCEPTION_IF_NULL(nnz); - - MS_EXCEPTION_IF_NULL(nnz->BuildValue()); - ShapeVector out_shape; - if (nnz->BuildValue()->isa() || nnz->BuildValue()->isa()) { - int64_t nnz_value = GetValue(nnz->BuildValue()); - out_shape.push_back(nnz_value); - } else { - MS_EXCEPTION(ValueError) << "Currently, only support Integer nnz."; - } - - MS_EXCEPTION_IF_NULL(indptr->shape()); - auto num_rows = indptr->shape()->shape()[0] - 1; - int csr_avg_rows = GetValue(nnz->BuildValue()) / num_rows; - primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); - primitive->set_attr(kIsCSR, MakeValue(true)); - - MS_EXCEPTION_IF_NULL(indptr->element()); - auto ret = std::make_shared(indptr->element()->BuildType(), out_shape); - return ret; -} - -AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: the row indices of a sparse coo tensor, and the size of its first dimension. 
- constexpr auto kCSRArgsSize = 2; - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, kCSRArgsSize); - auto row_indices = CheckArg(op_name, args_spec_list, 0); - auto height = CheckArg(op_name, args_spec_list, 1); - MS_EXCEPTION_IF_NULL(row_indices); - MS_EXCEPTION_IF_NULL(height); - CheckSparseIndicesDtypeInt32(row_indices->element()->BuildType(), "row_indices"); - MS_EXCEPTION_IF_NULL(height->BuildValue()); - ShapeVector out_shape; - if (height->BuildValue()->isa() || height->BuildValue()->isa()) { - int64_t height_value = GetValue(height->BuildValue()); - out_shape.push_back(height_value + 1); - } else { - MS_EXCEPTION(ValueError) << "Currently, only support Integer height."; - } - - MS_EXCEPTION_IF_NULL(row_indices->element()); - auto ret = std::make_shared(row_indices->element()->BuildType(), out_shape); - return ret; -} - -AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: three tensors and a tuple. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, kSizeFour); - auto indptr = CheckArg(op_name, args_spec_list, kIndexZero); - auto indices = CheckArg(op_name, args_spec_list, kIndexOne); - auto values = CheckArg(op_name, args_spec_list, kIndexTwo); - auto shape = CheckArg(op_name, args_spec_list, kIndexThree); - - auto indptr_dtype = indptr->element()->BuildType(); - auto indices_dtype = indices->element()->BuildType(); - CheckSparseIndicesDtype(indptr_dtype, "indptr"); - CheckSparseIndicesDtype(indices_dtype, "indices"); - - auto indptr_shp = indptr->shape()->shape(); - CheckSparseShape(indptr_shp.size(), kSizeOne, "Indptr"); - - auto indices_shp = indices->shape()->shape(); - CheckSparseShape(indices_shp.size(), kSizeOne, "Indices"); - - auto values_shp = values->shape()->shape(); - if (indices_shp[kIndexZero] != values_shp[kIndexZero]) { - MS_EXCEPTION(ValueError) << "Indices and values must have same size, but got: values length: " - << values_shp[kIndexZero] << ", indices length " << indices_shp[kIndexZero]; - } - - auto shape_value = shape->BuildValue()->cast(); - MS_EXCEPTION_IF_NULL(shape_value); - auto shp = shape_value->value(); - ShapeVector shape_vec; - (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(shape_vec), [](const ValuePtr &e) -> int64_t { - auto elem = GetValue(e); - return elem; - }); - if (values_shp.size() + 1 != shape_vec.size()) { - MS_EXCEPTION(ValueError) << "Values' dimension should equal to csr_tensor's dimension - 1, but got" - << "Values' dimension: " << values_shp.size() - << ", csr_tensor's dimension: " << shape_vec.size() << "."; - } - if (shape_vec[kIndexZero] + 1 != indptr_shp[kIndexZero]) { - MS_EXCEPTION(ValueError) << "Indptr must have length (1 + shape[0]), but got: " << indptr_shp[kIndexZero]; - } - size_t shape_size = 1; - auto shape_types = shape->ElementsType(); - for (size_t i = 0; i < shape_vec.size(); ++i) { - if (shape_vec[i] <= 0) { - MS_EXCEPTION(ValueError) << "The element of shape must be positive, but got " << shape_value->ToString(); - } - if ((i > 1) && (shape_vec[i] != values_shp[i - 1])) { - MS_EXCEPTION(ValueError) << "csr_tensor's shape should match with values' shape."; - } - if (!shape_types[i]->isa()) { - MS_EXCEPTION(TypeError) << "The element type of shape must be Int, but got " << shape_types[i]->ToString(); - } - shape_size *= LongToSize(shape_vec[i]); - } - if (static_cast(shape_size) < values_shp[kIndexZero]) { - 
MS_EXCEPTION(ValueError) << "Shape total size: " << shape_size << " is too small to hold " << values_shp[kIndexZero] - << " non-zero values."; - } - AbstractBasePtrList element_list{indptr, indices, values, shape}; - return std::make_shared(element_list); -} - -template -std::shared_ptr InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - return CheckArg(op_name, args_spec_list, 0); -} - -AbstractBasePtr InferImplCSRTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - auto csr_tensor = InferSparseAttr(primitive, args_spec_list); - MS_EXCEPTION_IF_NULL(csr_tensor->values()); - return csr_tensor->values(); -} - -AbstractBasePtr InferImplCSRTensorGetIndptr(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - auto csr_tensor = InferSparseAttr(primitive, args_spec_list); - MS_EXCEPTION_IF_NULL(csr_tensor->indptr()); - return csr_tensor->indptr(); -} - -AbstractBasePtr InferImplCSRTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - auto csr_tensor = InferSparseAttr(primitive, args_spec_list); - MS_EXCEPTION_IF_NULL(csr_tensor->indices()); - return csr_tensor->indices(); -} - -AbstractBasePtr InferImplCSRTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - auto csr_tensor = InferSparseAttr(primitive, args_spec_list); - MS_EXCEPTION_IF_NULL(csr_tensor->shape()); - return csr_tensor->shape(); -} - AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { const std::string op_name = primitive->name(); @@ -1008,48 +525,5 @@ AbstractBasePtr InferImplAdamApplyOneWithDecay(const AnalysisEnginePtr &, const AbstractBasePtrList rets = {add1, add0, sub0}; return std::make_shared(rets); } -AbstractBasePtr InferImplCSRMM(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a sparse tensor and a dense tensor. - constexpr auto kCSRMMInputsNum = 5; - constexpr auto kCSRMMShapeSize = 2; - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, kCSRMMInputsNum); - auto indptr = CheckArg(op_name, args_spec_list, 0); - auto indices = CheckArg(op_name, args_spec_list, 1); - auto values = CheckArg(op_name, args_spec_list, 2); - auto shape = CheckArg(op_name, args_spec_list, 3); - auto dense = CheckArg(op_name, args_spec_list, 4); - MS_EXCEPTION_IF_NULL(indptr); - MS_EXCEPTION_IF_NULL(indices); - MS_EXCEPTION_IF_NULL(values); - MS_EXCEPTION_IF_NULL(shape); - MS_EXCEPTION_IF_NULL(dense); - - CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr"); - CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices"); - - ShapeVector sparse_shape = ConvertToShapeVector(shape); - auto dense_shape = dense->shape()->shape(); - if (sparse_shape.size() != kCSRMMShapeSize || dense_shape.size() != kCSRMMShapeSize) { - MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMMShapeSize << "-D inputs! 
" - << "But csr tensor has " << sparse_shape.size() << " dimensions, " - << "and dense tensor has " << dense_shape.size() << " dimensions, "; - } - if (dense_shape[kIndexZero] != sparse_shape[kIndexOne]) { - MS_EXCEPTION(ValueError) << "The dense's shape[0] should be equal to csr tensor's shape[1]" - << ", but dense's shape[0] is: " << dense_shape[kIndexZero] - << " and csr tensor's shape[1] is " << sparse_shape[kIndexOne]; - } - - ShapeVector out_shape = {sparse_shape[kIndexZero], dense_shape[kIndexOne]}; - auto ret = std::make_shared(values->element()->BuildType(), out_shape); - // SetAttr - auto nnz_vec = indices->shape()->shape(); - auto csr_avg_rows = nnz_vec[kIndexZero] / dense_shape[kIndexZero]; - primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); - primitive->set_attr(kIsCSR, MakeValue(true)); - return ret; -} } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/ops/primitive_infer_map.cc b/mindspore/core/abstract/ops/primitive_infer_map.cc index 3de64638f76..46de748e1f6 100644 --- a/mindspore/core/abstract/ops/primitive_infer_map.cc +++ b/mindspore/core/abstract/ops/primitive_infer_map.cc @@ -318,31 +318,12 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimDebug, R{InferImplDebug, nullptr, true}}, // Dynamic shape testing {prim::kPrimGpuConvertToDynamicShape, R{InferImplGpuConvertToDynamicShape, nullptr, true}}, - // COOTensor - {prim::kPrimMakeCOOTensor, R{InferImplMakeCOOTensor, nullptr, true}}, - {prim::kPrimCOOTensorGetValues, R{InferImplCOOTensorGetValues, nullptr, true}}, - {prim::kPrimCOOTensorGetIndices, R{InferImplCOOTensorGetIndices, nullptr, true}}, - {prim::kPrimCOOTensorGetDenseShape, R{InferImplCOOTensorGetDenseShape, nullptr, true}}, // RowTensor {prim::kPrimMakeRowTensor, R{InferImplMakeRowTensor, nullptr, true}}, {prim::kPrimRowTensorGetValues, R{InferImplRowTensorGetValues, nullptr, true}}, {prim::kPrimRowTensorGetIndices, R{InferImplRowTensorGetIndices, nullptr, true}}, {prim::kPrimRowTensorGetDenseShape, R{InferImplRowTensorGetDenseShape, nullptr, true}}, {prim::kPrimRowTensorAdd, R{InferImplRowTensorAdd, nullptr, false}}, - // CSRTensor - {prim::kPrimMakeCSRTensor, R{InferImplMakeCSRTensor, nullptr, true}}, - {prim::kPrimCSRTensorGetValues, R{InferImplCSRTensorGetValues, nullptr, true}}, - {prim::kPrimCSRTensorGetIndptr, R{InferImplCSRTensorGetIndptr, nullptr, true}}, - {prim::kPrimCSRTensorGetIndices, R{InferImplCSRTensorGetIndices, nullptr, true}}, - {prim::kPrimCSRTensorGetDenseShape, R{InferImplCSRTensorGetDenseShape, nullptr, true}}, - {prim::kPrimCSRMul, R{InferImplCSRElementWise, nullptr, true}}, - {prim::kPrimCSRDiv, R{InferImplCSRElementWise, nullptr, true}}, - {prim::kPrimCSRMV, R{InferImplCSRMV, nullptr, true}}, - {prim::kPrimCSRMM, R{InferImplCSRMM, nullptr, true}}, - {prim::kPrimCSRReduceSum, R{InferImplCSRReduceSum, nullptr, true}}, - {prim::kPrimCSRGather, R{InferImplCSRGather, nullptr, true}}, - {prim::kPrimCSR2COO, R{InferImplCSR2COO, nullptr, true}}, - {prim::kPrimCOO2CSR, R{InferImplCOO2CSR, nullptr, true}}, // Comm Ops {prim::kPrimAllSwap, R{InferImplAllSwap, nullptr, true}}, {prim::kPrimMemCpyAsync, R{InferImplMemCpyAsync, nullptr, true}}, diff --git a/mindspore/core/ops/coo_tensor_get_indices.cc b/mindspore/core/ops/coo_tensor_get_indices.cc new file mode 100644 index 00000000000..495796f107f --- /dev/null +++ b/mindspore/core/ops/coo_tensor_get_indices.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 
(the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/coo_tensor_get_indices.h" + +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" +#include "ops/op_utils.h" +#include "ops/primitive_c.h" +#include "utils/anf_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +AbstractBasePtr COOTensorGetIndicesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list) { + auto coo_tensor = InferSparseAttr(primitive, args_spec_list); + MS_EXCEPTION_IF_NULL(coo_tensor->indices()); + return coo_tensor->indices(); +} +MIND_API_OPERATOR_IMPL(COOTensorGetIndices, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(COOTensorGetIndices, prim::kPrimCOOTensorGetIndices, COOTensorGetIndicesInfer, nullptr, + true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/coo_tensor_get_indices.h b/mindspore/core/ops/coo_tensor_get_indices.h new file mode 100644 index 00000000000..8a8893918c4 --- /dev/null +++ b/mindspore/core/ops/coo_tensor_get_indices.h @@ -0,0 +1,39 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_COOTENSOR_GET_INDICES_H_ +#define MINDSPORE_CORE_OPS_COOTENSOR_GET_INDICES_H_ + +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameCOOTensorGetIndices = "COOTensorGetIndices"; +class MIND_API COOTensorGetIndices : public BaseOperator { + public: + MIND_API_BASE_MEMBER(COOTensorGetIndices); + /// \brief Constructor. + COOTensorGetIndices() : BaseOperator(kNameCOOTensorGetIndices) {} +}; +abstract::AbstractBasePtr COOTensorGetIndicesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_COOTENSOR_GET_INDICES_H_ diff --git a/mindspore/core/ops/coo_tensor_get_shape.cc b/mindspore/core/ops/coo_tensor_get_shape.cc new file mode 100644 index 00000000000..6ebd7b48667 --- /dev/null +++ b/mindspore/core/ops/coo_tensor_get_shape.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/coo_tensor_get_shape.h" + +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" +#include "ops/op_utils.h" +#include "ops/primitive_c.h" +#include "utils/anf_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +abstract::AbstractBasePtr COOTensorGetShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list) { + auto coo_tensor = InferSparseAttr(primitive, args_spec_list); + MS_EXCEPTION_IF_NULL(coo_tensor->shape()); + return coo_tensor->shape(); +} +MIND_API_OPERATOR_IMPL(COOTensorGetShape, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(COOTensorGetShape, prim::kPrimCOOTensorGetDenseShape, COOTensorGetShapeInfer, nullptr, + true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/coo_tensor_get_shape.h b/mindspore/core/ops/coo_tensor_get_shape.h new file mode 100644 index 00000000000..82ca268ea10 --- /dev/null +++ b/mindspore/core/ops/coo_tensor_get_shape.h @@ -0,0 +1,39 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_COOTENSOR_GET_SHAPE_H_ +#define MINDSPORE_CORE_OPS_COOTENSOR_GET_SHAPE_H_ + +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameCOOTensorGetShape = "COOTensorGetShape"; +class MIND_API COOTensorGetShape : public BaseOperator { + public: + MIND_API_BASE_MEMBER(COOTensorGetShape); + /// \brief Constructor. + COOTensorGetShape() : BaseOperator(kNameCOOTensorGetShape) {} +}; +abstract::AbstractBasePtr COOTensorGetShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_COOTENSOR_GET_SHAPE_H_ diff --git a/mindspore/core/ops/coo_tensor_get_values.cc b/mindspore/core/ops/coo_tensor_get_values.cc new file mode 100644 index 00000000000..3bf70599ba0 --- /dev/null +++ b/mindspore/core/ops/coo_tensor_get_values.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/coo_tensor_get_values.h" + +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" +#include "ops/op_utils.h" +#include "ops/primitive_c.h" +#include "utils/anf_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +abstract::AbstractBasePtr COOTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list) { + auto coo_tensor = InferSparseAttr(primitive, args_spec_list); + MS_EXCEPTION_IF_NULL(coo_tensor->values()); + return coo_tensor->values(); +} +MIND_API_OPERATOR_IMPL(COOTensorGetValues, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(COOTensorGetValues, prim::kPrimCOOTensorGetValues, COOTensorGetValuesInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/coo_tensor_get_values.h b/mindspore/core/ops/coo_tensor_get_values.h new file mode 100644 index 00000000000..b251e5ed03f --- /dev/null +++ b/mindspore/core/ops/coo_tensor_get_values.h @@ -0,0 +1,39 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_COOTENSOR_GET_VALUES_H_ +#define MINDSPORE_CORE_OPS_COOTENSOR_GET_VALUES_H_ + +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameCOOTensorGetValues = "COOTensorGetValues"; +class MIND_API COOTensorGetValues : public BaseOperator { + public: + MIND_API_BASE_MEMBER(COOTensorGetValues); + /// \brief Constructor. + COOTensorGetValues() : BaseOperator(kNameCOOTensorGetValues) {} +}; +abstract::AbstractBasePtr COOTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_COOTENSOR_GET_VALUES_H_ diff --git a/mindspore/core/ops/coo_to_csr.cc b/mindspore/core/ops/coo_to_csr.cc new file mode 100644 index 00000000000..e11ac16c13b --- /dev/null +++ b/mindspore/core/ops/coo_to_csr.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/coo_to_csr.h" + +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" +#include "ops/op_utils.h" +#include "utils/anf_utils.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +using abstract::AbstractScalar; +using abstract::AbstractTensor; +using abstract::AbstractTuple; +AbstractBasePtr COO2CSRInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + // Inputs: the row indices of a sparse coo tensor, and the size of its first dimension. + constexpr auto kCSRArgsSize = 2; + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, input_args, kCSRArgsSize); + auto row_indices = abstract::CheckArg(op_name, input_args, 0); + auto height = abstract::CheckArg(op_name, input_args, 1); + MS_EXCEPTION_IF_NULL(row_indices); + MS_EXCEPTION_IF_NULL(height); + CheckSparseIndicesDtypeInt32(row_indices->element()->BuildType(), "row_indices"); + MS_EXCEPTION_IF_NULL(height->BuildValue()); + ShapeVector out_shape; + if (height->BuildValue()->isa() || height->BuildValue()->isa()) { + int64_t height_value = GetValue(height->BuildValue()); + out_shape.push_back(height_value + 1); + } else { + MS_EXCEPTION(ValueError) << "Currently, only support Integer height."; + } + + MS_EXCEPTION_IF_NULL(row_indices->element()); + auto ret = std::make_shared(row_indices->element()->BuildType(), out_shape); + return ret; +} +MIND_API_OPERATOR_IMPL(COO2CSR, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(COO2CSR, prim::kPrimCOO2CSR, COO2CSRInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/coo_to_csr.h b/mindspore/core/ops/coo_to_csr.h new file mode 100644 index 00000000000..e1fb9eb735f --- /dev/null +++ b/mindspore/core/ops/coo_to_csr.h @@ -0,0 +1,43 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_COO_TO_CSR +#define MINDSPORE_CORE_OPS_COO_TO_CSR +#include +#include +#include +#include +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameCOO2CSR = "COO2CSR"; +/// \brief Converts the row indices of a COOTensor to the indptr of a CSRTensor. +class MIND_API COO2CSR : public BaseOperator { + public: + MIND_API_BASE_MEMBER(COO2CSR); + /// \brief Constructor. 
+ COO2CSR() : BaseOperator(kNameCOO2CSR) { InitIOName({"row_indices", "height"}, {"output"}); } +}; +abstract::AbstractBasePtr COO2CSRInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_COO_TO_CSR diff --git a/mindspore/core/ops/csr_elementwise.cc b/mindspore/core/ops/csr_elementwise.cc new file mode 100644 index 00000000000..0dae742a868 --- /dev/null +++ b/mindspore/core/ops/csr_elementwise.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/csr_elementwise.h" + +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" +#include "ops/op_utils.h" +#include "utils/anf_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +using abstract::AbstractTensor; +using abstract::AbstractTuple; +AbstractBasePtr CSRElementWiseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + // Inputs: a sparse tensor and a dense tensor. 
+ constexpr auto kCSRElementwiseInputsNum = 5; + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, input_args, kCSRElementwiseInputsNum); + auto indptr = abstract::CheckArg(op_name, input_args, 0); + auto indices = abstract::CheckArg(op_name, input_args, 1); + auto values = abstract::CheckArg(op_name, input_args, 2); + auto shape = abstract::CheckArg(op_name, input_args, 3); + auto dense = abstract::CheckArg(op_name, input_args, 4); + MS_EXCEPTION_IF_NULL(indptr); + MS_EXCEPTION_IF_NULL(indices); + MS_EXCEPTION_IF_NULL(values); + MS_EXCEPTION_IF_NULL(shape); + MS_EXCEPTION_IF_NULL(dense); + + CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr"); + CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices"); + + ShapeVector sparse_shape = ConvertToShapeVector(shape); + auto dense_shape = dense->shape()->shape(); + CheckSparseShape(sparse_shape, dense_shape); + auto ret = values->Broaden(); + // SetAttr + auto nnz_vec = indices->shape()->shape(); + auto csr_avg_rows = nnz_vec[0] / dense_shape[0]; + primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); + primitive->set_attr(kIsCSR, MakeValue(true)); + return ret; +} +MIND_API_OPERATOR_IMPL(CSRMul, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(CSRMul, prim::kPrimCSRMul, CSRElementWiseInfer, nullptr, true); +MIND_API_OPERATOR_IMPL(CSRDiv, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(CSRDiv, prim::kPrimCSRDiv, CSRElementWiseInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/csr_elementwise.h b/mindspore/core/ops/csr_elementwise.h new file mode 100644 index 00000000000..6b77ee885bb --- /dev/null +++ b/mindspore/core/ops/csr_elementwise.h @@ -0,0 +1,56 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_CSR_ELEMENTWISE +#define MINDSPORE_CORE_OPS_CSR_ELEMENTWISE +#include +#include +#include +#include +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameCSRMul = "CSRMul"; +constexpr auto kNameCSRDiv = "CSRDiv"; +/// \brief CSRTensor elementwise operation. +class MIND_API CSRMul : public BaseOperator { + public: + MIND_API_BASE_MEMBER(CSRMul); + /// \brief Constructor. + CSRMul() : BaseOperator(kNameCSRMul) { + InitIOName({"indptr", "indices", "values", "dense_shape", "dense_tensor"}, {"output"}); + } +}; + +class MIND_API CSRDiv : public BaseOperator { + public: + MIND_API_BASE_MEMBER(CSRDiv); + /// \brief Constructor. 
+ CSRDiv() : BaseOperator(kNameCSRDiv) { + InitIOName({"indptr", "indices", "values", "dense_shape", "dense_tensor"}, {"output"}); + } +}; + +abstract::AbstractBasePtr CSRElementWiseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_CSR_ELEMENTWISE diff --git a/mindspore/core/ops/csr_gather.cc b/mindspore/core/ops/csr_gather.cc new file mode 100644 index 00000000000..cde98be796c --- /dev/null +++ b/mindspore/core/ops/csr_gather.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/csr_gather.h" + +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" +#include "ops/op_utils.h" +#include "utils/anf_utils.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +using abstract::AbstractTensor; +using abstract::AbstractTuple; +AbstractBasePtr CSRGatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + // Inputs: the indptr and indices of a sparse csr tensor, a dense tensor, and the shape of the sparse tensor. 
+ constexpr size_t csr_row_num = 2; + const std::string op_name = primitive->name(); + abstract::CheckArgsSize(op_name, input_args, kSizeFour); + auto indptr = abstract::CheckArg(op_name, input_args, kIndexZero); + auto indices = abstract::CheckArg(op_name, input_args, kIndexOne); + auto dense = abstract::CheckArg(op_name, input_args, kIndexTwo); + auto sparse_shape = abstract::CheckArg(op_name, input_args, kIndexThree); + MS_EXCEPTION_IF_NULL(indptr); + MS_EXCEPTION_IF_NULL(indices); + MS_EXCEPTION_IF_NULL(dense); + MS_EXCEPTION_IF_NULL(sparse_shape); + + CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr"); + CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices"); + + auto shape_value = sparse_shape->BuildValue()->cast(); + MS_EXCEPTION_IF_NULL(shape_value); + auto nnz_vec = indices->shape()->shape(); + int64_t csr_avg_rows = nnz_vec[0] / GetValue(shape_value->value()[0]); + primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); + primitive->set_attr(kIsCSR, MakeValue(true)); + + MS_EXCEPTION_IF_NULL(indices->shape()); + ShapeVector out_shape = indices->shape()->shape(); + MS_EXCEPTION_IF_NULL(dense->shape()); + ShapeVector dense_shape = dense->shape()->shape(); + for (size_t i = csr_row_num; i < dense_shape.size(); ++i) { + out_shape.push_back(dense_shape[i]); + } + MS_EXCEPTION_IF_NULL(dense->element()); + auto ret = std::make_shared(dense->element()->BuildType(), out_shape); + return ret; +} +MIND_API_OPERATOR_IMPL(CSRGather, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(CSRGather, prim::kPrimCSRGather, CSRGatherInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/csr_gather.h b/mindspore/core/ops/csr_gather.h new file mode 100644 index 00000000000..5cee9bddd5a --- /dev/null +++ b/mindspore/core/ops/csr_gather.h @@ -0,0 +1,43 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_CSR_GATHER +#define MINDSPORE_CORE_OPS_CSR_GATHER +#include +#include +#include +#include +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameCSRGather = "CSRGather"; +/// \brief Returns the values of a CSRTensor indexed from a dense tensor using indptr and indices. +class MIND_API CSRGather : public BaseOperator { + public: + MIND_API_BASE_MEMBER(CSRGather); + /// \brief Constructor. 
+ CSRGather() : BaseOperator(kNameCSRGather) { InitIOName({"indptr", "indices", "dense", "dense_shape"}, {"output"}); } +}; +abstract::AbstractBasePtr CSRGatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_CSR_GATHER diff --git a/mindspore/core/ops/csr_mm.cc b/mindspore/core/ops/csr_mm.cc new file mode 100644 index 00000000000..efe273680d1 --- /dev/null +++ b/mindspore/core/ops/csr_mm.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/csr_mm.h" + +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" +#include "ops/op_utils.h" +#include "utils/anf_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +using abstract::AbstractTensor; +using abstract::AbstractTuple; +AbstractBasePtr CSRMMInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + // Inputs: a sparse tensor and a dense tensor. + constexpr auto kCSRMMInputsNum = 5; + constexpr auto kCSRMMShapeSize = 2; + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, input_args, kCSRMMInputsNum); + auto indptr = abstract::CheckArg(op_name, input_args, 0); + auto indices = abstract::CheckArg(op_name, input_args, 1); + auto values = abstract::CheckArg(op_name, input_args, 2); + auto shape = abstract::CheckArg(op_name, input_args, 3); + auto dense = abstract::CheckArg(op_name, input_args, 4); + MS_EXCEPTION_IF_NULL(indptr); + MS_EXCEPTION_IF_NULL(indices); + MS_EXCEPTION_IF_NULL(values); + MS_EXCEPTION_IF_NULL(shape); + MS_EXCEPTION_IF_NULL(dense); + + CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr"); + CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices"); + + ShapeVector sparse_shape = ConvertToShapeVector(shape); + auto dense_shape = dense->shape()->shape(); + if (sparse_shape.size() != kCSRMMShapeSize || dense_shape.size() != kCSRMMShapeSize) { + MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMMShapeSize << "-D inputs! " + << "But csr tensor has " << sparse_shape.size() << " dimensions, " + << "and dense tensor has " << dense_shape.size() << " dimension(s). 
"; + } + if (dense_shape[kIndexZero] != sparse_shape[kIndexOne]) { + MS_EXCEPTION(ValueError) << "The dense's shape[0] should be equal to csr tensor's shape[1]" + << ", but dense's shape[0] is: " << dense_shape[kIndexZero] + << " and csr tensor's shape[1] is " << sparse_shape[kIndexOne]; + } + + ShapeVector out_shape = {sparse_shape[kIndexZero], dense_shape[kIndexOne]}; + auto ret = std::make_shared(values->element()->BuildType(), out_shape); + // SetAttr + auto nnz_vec = indices->shape()->shape(); + auto csr_avg_rows = nnz_vec[kIndexZero] / dense_shape[kIndexZero]; + primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); + primitive->set_attr(kIsCSR, MakeValue(true)); + return ret; +} +MIND_API_OPERATOR_IMPL(CSRMM, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(CSRMM, prim::kPrimCSRMM, CSRMMInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/csr_mm.h b/mindspore/core/ops/csr_mm.h new file mode 100644 index 00000000000..afe4caf0309 --- /dev/null +++ b/mindspore/core/ops/csr_mm.h @@ -0,0 +1,45 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_CSR_MM +#define MINDSPORE_CORE_OPS_CSR_MM +#include +#include +#include +#include +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameCSRMM = "CSRMM"; +/// \brief Sparse matrix-matrix multiplication. +class MIND_API CSRMM : public BaseOperator { + public: + MIND_API_BASE_MEMBER(CSRMM); + /// \brief Constructor. + CSRMM() : BaseOperator(kNameCSRMM) { + InitIOName({"indptr", "indices", "values", "dense_shape", "dense_tensor"}, {"output"}); + } +}; +abstract::AbstractBasePtr CSRMMInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_CSR_MM diff --git a/mindspore/core/ops/csr_mv.cc b/mindspore/core/ops/csr_mv.cc new file mode 100644 index 00000000000..4be06c41ed9 --- /dev/null +++ b/mindspore/core/ops/csr_mv.cc @@ -0,0 +1,77 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ops/csr_mv.h" + +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" +#include "ops/op_utils.h" +#include "utils/anf_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +using abstract::AbstractTensor; +using abstract::AbstractTuple; +AbstractBasePtr CSRMVInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + constexpr auto kCSRMVInputsNum = 5; + constexpr auto kCSRMVShapeSize = 2; + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, input_args, kCSRMVInputsNum); + auto indptr = abstract::CheckArg(op_name, input_args, 0); + auto indices = abstract::CheckArg(op_name, input_args, 1); + auto values = abstract::CheckArg(op_name, input_args, 2); + auto shape = abstract::CheckArg(op_name, input_args, 3); + auto dense = abstract::CheckArg(op_name, input_args, 4); + MS_EXCEPTION_IF_NULL(indptr); + MS_EXCEPTION_IF_NULL(indices); + MS_EXCEPTION_IF_NULL(values); + MS_EXCEPTION_IF_NULL(shape); + MS_EXCEPTION_IF_NULL(dense); + + CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr"); + CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices"); + + ShapeVector sparse_shape = ConvertToShapeVector(shape); + ShapeVector dense_shape = dense->shape()->shape(); + if (sparse_shape.size() != kCSRMVShapeSize || dense_shape.size() != kCSRMVShapeSize) { + MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMVShapeSize << "-D inputs! " + << "But csr tensor has " << sparse_shape.size() << " dimensions, " + << "and dense tensor has " << dense_shape.size() << " dimension(s). "; + } + if (dense_shape[kIndexZero] != sparse_shape[kIndexOne] || dense_shape[kIndexOne] != 1) { + MS_EXCEPTION(ValueError) << "The dense_vector's shape should be (" << sparse_shape[kIndexOne] << ", 1)" + << ", but its current shape is: " + << "(" << dense_shape[kIndexZero] << ", " << dense_shape[kIndexOne] << ")."; + } + + ShapeVector out_shape = {sparse_shape[kIndexZero], dense_shape[kIndexOne]}; + auto ret = std::make_shared(values->element()->BuildType(), out_shape); + // SetAttr + auto nnz_vec = indices->shape()->shape(); + auto csr_avg_rows = nnz_vec[kIndexZero] / dense_shape[kIndexZero]; + primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); + primitive->set_attr(kIsCSR, MakeValue(true)); + return ret; +} +MIND_API_OPERATOR_IMPL(CSRMV, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(CSRMV, prim::kPrimCSRMV, CSRMVInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/csr_mv.h b/mindspore/core/ops/csr_mv.h new file mode 100644 index 00000000000..fd22bd90309 --- /dev/null +++ b/mindspore/core/ops/csr_mv.h @@ -0,0 +1,45 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CORE_OPS_CSR_MV +#define MINDSPORE_CORE_OPS_CSR_MV +#include +#include +#include +#include +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameCSRMV = "CSRMV"; +/// \brief Sparse matrix-vector multiplication. +class MIND_API CSRMV : public BaseOperator { + public: + MIND_API_BASE_MEMBER(CSRMV); + /// \brief Constructor. + CSRMV() : BaseOperator(kNameCSRMV) { + InitIOName({"indptr", "indices", "values", "dense_shape", "dense_tensor"}, {"output"}); + } +}; +abstract::AbstractBasePtr CSRMVInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_CSR_MV diff --git a/mindspore/core/ops/csr_reducesum.cc b/mindspore/core/ops/csr_reducesum.cc new file mode 100644 index 00000000000..e44590029ba --- /dev/null +++ b/mindspore/core/ops/csr_reducesum.cc @@ -0,0 +1,85 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/csr_reducesum.h" + +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" +#include "ops/op_utils.h" +#include "utils/anf_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +using abstract::AbstractScalar; +using abstract::AbstractTensor; +using abstract::AbstractTuple; +AbstractBasePtr CSRReduceSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + // Inputs: a sparse tensor and an axis. + constexpr auto kCSRReduceSumInputsNum = 5; + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, input_args, kCSRReduceSumInputsNum); + auto indptr = abstract::CheckArg(op_name, input_args, 0); + auto indices = abstract::CheckArg(op_name, input_args, 1); + auto values = abstract::CheckArg(op_name, input_args, 2); + auto shape = abstract::CheckArg(op_name, input_args, 3); + auto axis = abstract::CheckArg(op_name, input_args, 4); + MS_EXCEPTION_IF_NULL(indptr); + MS_EXCEPTION_IF_NULL(indices); + MS_EXCEPTION_IF_NULL(values); + MS_EXCEPTION_IF_NULL(shape); + MS_EXCEPTION_IF_NULL(axis); + + CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr"); + CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices"); + + ShapeVector sparse_shape = ConvertToShapeVector(shape); + ShapeVector out_shape = sparse_shape; + MS_EXCEPTION_IF_NULL(axis->BuildValue()); + if (axis->BuildValue()->isa() || axis->BuildValue()->isa()) { + int64_t axis_value = GetValue(axis->BuildValue()); + int64_t dim = static_cast(sparse_shape.size()); + if (axis_value != 1 && axis_value != 1 - dim) { + MS_EXCEPTION(ValueError) << "For CSRReduceSum, `axis` should be 1 or 1-dim. 
But got `axis`: " << axis_value + << "and `1- dim`: " << 1 - dim << "."; + } + if (axis_value < 0) { + axis_value += dim; + } + out_shape[LongToSize(axis_value)] = 1; + primitive->set_attr(kCSRAxis, MakeValue(axis_value)); + } else { + MS_EXCEPTION(TypeError) << "For CSRReduceSum, `axis` should be int32 or int64, but got " + << axis->BuildType()->ToString() << "."; + } + + MS_EXCEPTION_IF_NULL(values->element()); + auto ret = std::make_shared(values->element()->BuildType(), out_shape); + // SetAttr + auto nnz_vec = indices->shape()->shape(); + auto csr_avg_rows = nnz_vec[0] / sparse_shape[0]; + primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); + primitive->set_attr(kIsCSR, MakeValue(true)); + return ret; +} +MIND_API_OPERATOR_IMPL(CSRReduceSum, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(CSRReduceSum, prim::kPrimCSRReduceSum, CSRReduceSumInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/csr_reducesum.h b/mindspore/core/ops/csr_reducesum.h new file mode 100644 index 00000000000..b9d76434830 --- /dev/null +++ b/mindspore/core/ops/csr_reducesum.h @@ -0,0 +1,45 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_CSR_REDUCESUM +#define MINDSPORE_CORE_OPS_CSR_REDUCESUM +#include +#include +#include +#include +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameCSRReduceSum = "CSRReduceSum"; +/// \brief CSRTensor reducesum. +class MIND_API CSRReduceSum : public BaseOperator { + public: + MIND_API_BASE_MEMBER(CSRReduceSum); + /// \brief Constructor. + CSRReduceSum() : BaseOperator(kNameCSRReduceSum) { + InitIOName({"indptr", "indices", "values", "dense_shape", "axis"}, {"output"}); + } +}; +abstract::AbstractBasePtr CSRReduceSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_CSR_REDUCESUM diff --git a/mindspore/core/ops/csr_tensor_get_indices.cc b/mindspore/core/ops/csr_tensor_get_indices.cc new file mode 100644 index 00000000000..91572948956 --- /dev/null +++ b/mindspore/core/ops/csr_tensor_get_indices.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ops/csr_tensor_get_indices.h" + +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" +#include "ops/op_utils.h" +#include "ops/primitive_c.h" +#include "utils/anf_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +abstract::AbstractBasePtr CSRTensorGetIndicesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list) { + auto csr_tensor = InferSparseAttr(primitive, args_spec_list); + MS_EXCEPTION_IF_NULL(csr_tensor->indices()); + return csr_tensor->indices(); +} +MIND_API_OPERATOR_IMPL(CSRTensorGetIndices, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(CSRTensorGetIndices, prim::kPrimCSRTensorGetIndices, CSRTensorGetIndicesInfer, nullptr, + true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/csr_tensor_get_indices.h b/mindspore/core/ops/csr_tensor_get_indices.h new file mode 100644 index 00000000000..be95113b04d --- /dev/null +++ b/mindspore/core/ops/csr_tensor_get_indices.h @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDICES_H_ +#define MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDICES_H_ + +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameCSRTensorGetIndices = "CSRTensorGetIndices"; +/// \brief CSRTensorGetIndices op is used to get indptr in CSRTensor. +class MIND_API CSRTensorGetIndices : public BaseOperator { + public: + MIND_API_BASE_MEMBER(CSRTensorGetIndices); + /// \brief Constructor. + CSRTensorGetIndices() : BaseOperator(kNameCSRTensorGetIndices) {} +}; +abstract::AbstractBasePtr CSRTensorGetIndicesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDICES_H_ diff --git a/mindspore/core/ops/csr_tensor_get_indptr.cc b/mindspore/core/ops/csr_tensor_get_indptr.cc new file mode 100644 index 00000000000..daffa584985 --- /dev/null +++ b/mindspore/core/ops/csr_tensor_get_indptr.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ops/csr_tensor_get_indptr.h" + +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" +#include "ops/op_utils.h" +#include "ops/primitive_c.h" +#include "utils/anf_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +abstract::AbstractBasePtr CSRTensorGetIndptrInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list) { + auto csr_tensor = InferSparseAttr(primitive, args_spec_list); + MS_EXCEPTION_IF_NULL(csr_tensor->indptr()); + return csr_tensor->indptr(); +} +MIND_API_OPERATOR_IMPL(CSRTensorGetIndptr, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(CSRTensorGetIndptr, prim::kPrimCSRTensorGetIndptr, CSRTensorGetIndptrInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/csr_tensor_get_indptr.h b/mindspore/core/ops/csr_tensor_get_indptr.h new file mode 100644 index 00000000000..019861d7df5 --- /dev/null +++ b/mindspore/core/ops/csr_tensor_get_indptr.h @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDPTR_H_ +#define MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDPTR_H_ + +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameCSRTensorGetIndptr = "CSRTensorGetIndptr"; +/// \brief CSRTensorGetIndptr op is used to get indptr in CSRTensor. +class MIND_API CSRTensorGetIndptr : public BaseOperator { + public: + MIND_API_BASE_MEMBER(CSRTensorGetIndptr); + /// \brief Constructor. + CSRTensorGetIndptr() : BaseOperator(kNameCSRTensorGetIndptr) {} +}; +abstract::AbstractBasePtr CSRTensorGetIndptrInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDPTR_H_ diff --git a/mindspore/core/ops/csr_tensor_get_shape.cc b/mindspore/core/ops/csr_tensor_get_shape.cc new file mode 100644 index 00000000000..6ba0b4dd0af --- /dev/null +++ b/mindspore/core/ops/csr_tensor_get_shape.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ops/csr_tensor_get_shape.h" + +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" +#include "ops/op_utils.h" +#include "ops/primitive_c.h" +#include "utils/anf_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +abstract::AbstractBasePtr CSRTensorGetShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list) { + auto csr_tensor = InferSparseAttr(primitive, args_spec_list); + MS_EXCEPTION_IF_NULL(csr_tensor->shape()); + return csr_tensor->shape(); +} +MIND_API_OPERATOR_IMPL(CSRTensorGetShape, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(CSRTensorGetShape, prim::kPrimCSRTensorGetDenseShape, CSRTensorGetShapeInfer, nullptr, + true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/csr_tensor_get_shape.h b/mindspore/core/ops/csr_tensor_get_shape.h new file mode 100644 index 00000000000..b0caf0b697e --- /dev/null +++ b/mindspore/core/ops/csr_tensor_get_shape.h @@ -0,0 +1,39 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_CSRTENSOR_GET_SHAPE_H_ +#define MINDSPORE_CORE_OPS_CSRTENSOR_GET_SHAPE_H_ + +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameCSRTensorGetShape = "CSRTensorGetShape"; +class MIND_API CSRTensorGetShape : public BaseOperator { + public: + MIND_API_BASE_MEMBER(CSRTensorGetShape); + /// \brief Constructor. + CSRTensorGetShape() : BaseOperator(kNameCSRTensorGetShape) {} +}; +abstract::AbstractBasePtr CSRTensorGetShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_CSRTENSOR_GET_SHAPE_H_ diff --git a/mindspore/core/ops/csr_tensor_get_values.cc b/mindspore/core/ops/csr_tensor_get_values.cc new file mode 100644 index 00000000000..a43af0fb00e --- /dev/null +++ b/mindspore/core/ops/csr_tensor_get_values.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ops/csr_tensor_get_values.h" + +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" +#include "ops/op_utils.h" +#include "ops/primitive_c.h" +#include "utils/anf_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +abstract::AbstractBasePtr CSRTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list) { + auto csr_tensor = InferSparseAttr(primitive, args_spec_list); + MS_EXCEPTION_IF_NULL(csr_tensor->values()); + return csr_tensor->values(); +} +MIND_API_OPERATOR_IMPL(CSRTensorGetValues, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(CSRTensorGetValues, prim::kPrimCSRTensorGetValues, CSRTensorGetValuesInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/csr_tensor_get_values.h b/mindspore/core/ops/csr_tensor_get_values.h new file mode 100644 index 00000000000..84e20020207 --- /dev/null +++ b/mindspore/core/ops/csr_tensor_get_values.h @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_CSRTENSOR_GET_VALUES_H_ +#define MINDSPORE_CORE_OPS_CSRTENSOR_GET_VALUES_H_ + +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameCSRTensorGetValues = "CSRTensorGetValues"; +/// \brief CSRTensorGetIndices op is used to get indptr in CSRTensor. +class MIND_API CSRTensorGetValues : public BaseOperator { + public: + MIND_API_BASE_MEMBER(CSRTensorGetValues); + /// \brief Constructor. + CSRTensorGetValues() : BaseOperator(kNameCSRTensorGetValues) {} +}; +abstract::AbstractBasePtr CSRTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_CSRTENSOR_GET_VALUES_H_ diff --git a/mindspore/core/ops/csr_to_coo.cc b/mindspore/core/ops/csr_to_coo.cc new file mode 100644 index 00000000000..18b5ce7f716 --- /dev/null +++ b/mindspore/core/ops/csr_to_coo.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ops/csr_to_coo.h" + +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" +#include "ops/op_utils.h" +#include "utils/anf_utils.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +using abstract::AbstractScalar; +using abstract::AbstractTensor; +using abstract::AbstractTuple; +AbstractBasePtr CSR2COOInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + // Inputs: the indptr of a sparse csr tensor, and the number of non-zero elements. + constexpr auto kCSRArgsSize = 2; + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, input_args, kCSRArgsSize); + auto indptr = abstract::CheckArg(op_name, input_args, 0); + CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr"); + + auto nnz = abstract::CheckArg(op_name, input_args, 1); + MS_EXCEPTION_IF_NULL(indptr); + MS_EXCEPTION_IF_NULL(nnz); + + MS_EXCEPTION_IF_NULL(nnz->BuildValue()); + ShapeVector out_shape; + if (nnz->BuildValue()->isa() || nnz->BuildValue()->isa()) { + int64_t nnz_value = GetValue(nnz->BuildValue()); + out_shape.push_back(nnz_value); + } else { + MS_EXCEPTION(ValueError) << "Currently, only support Integer nnz."; + } + + MS_EXCEPTION_IF_NULL(indptr->shape()); + auto num_rows = indptr->shape()->shape()[0] - 1; + int csr_avg_rows = GetValue(nnz->BuildValue()) / num_rows; + primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); + primitive->set_attr(kIsCSR, MakeValue(true)); + + MS_EXCEPTION_IF_NULL(indptr->element()); + auto ret = std::make_shared(indptr->element()->BuildType(), out_shape); + return ret; +} +MIND_API_OPERATOR_IMPL(CSR2COO, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(CSR2COO, prim::kPrimCSR2COO, CSR2COOInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/csr_to_coo.h b/mindspore/core/ops/csr_to_coo.h new file mode 100644 index 00000000000..752472fbf48 --- /dev/null +++ b/mindspore/core/ops/csr_to_coo.h @@ -0,0 +1,43 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_CSR_TO_COO +#define MINDSPORE_CORE_OPS_CSR_TO_COO +#include +#include +#include +#include +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameCSR2COO = "CSR2COO"; +/// \brief Converts the indptr of a CSRTensor to the row indices of a COOTensor. +class MIND_API CSR2COO : public BaseOperator { + public: + MIND_API_BASE_MEMBER(CSR2COO); + /// \brief Constructor. 
+ CSR2COO() : BaseOperator(kNameCSR2COO) { InitIOName({"indptr", "nnz"}, {"output"}); } +}; +abstract::AbstractBasePtr CSR2COOInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_CSR_TO_COO diff --git a/mindspore/core/ops/make_cootensor.cc b/mindspore/core/ops/make_cootensor.cc new file mode 100644 index 00000000000..0b094818e0d --- /dev/null +++ b/mindspore/core/ops/make_cootensor.cc @@ -0,0 +1,104 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "ops/make_cootensor.h" + +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" +#include "ops/op_utils.h" +#include "ops/primitive_c.h" +#include "utils/anf_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +using abstract::AbstractTensor; +using abstract::AbstractTuple; +AbstractBasePtr MakeCOOTensorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list) { + // Inputs: two tensors and a tuple. 
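+  // Expected inputs: indices is an (nnz, 2) integer Tensor of (row, col) pairs, values is an (nnz,)
+  // Tensor, and dense_shape is a tuple of positive ints with one entry per sparse dimension; the
+  // checks below enforce these constraints before returning the (indices, values, dense_shape) tuple.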
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, kSizeThree);
+  auto indices = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexZero);
+  auto values = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexOne);
+  auto dense_shape = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, kIndexTwo);
+
+  auto indices_dtype = indices->element()->BuildType();
+  CheckSparseIndicesDtype(indices_dtype, "Indices");
+
+  auto indices_shp = indices->shape()->shape();
+  CheckSparseShape(indices_shp.size(), kSizeTwo, "Indices");
+
+  auto values_shp = values->shape()->shape();
+  CheckSparseShape(values_shp.size(), kSizeOne, "Values");
+
+  if (indices_shp[kIndexZero] != values_shp[kIndexZero]) {
+    MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexZero << "]` must be equal to `values.shape["
+                             << kIndexZero << "]`, but got `indices.shape[" << kIndexZero
+                             << "]`: " << indices_shp[kIndexZero] << " and `values.shape[" << kIndexZero
+                             << "]`: " << values_shp[kIndexZero];
+  }
+  constexpr int64_t kDimTwo = 2;
+  if (indices_shp[kIndexOne] != kDimTwo) {
+    MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexOne << "]` must be " << kDimTwo
+                             << ", but got " << indices_shp[kIndexOne];
+  }
+
+  for (const auto &elem_type : dense_shape->ElementsType()) {
+    if (!elem_type->isa<Int>()) {
+      MS_EXCEPTION(TypeError) << "For COOTensor, the element type of `shape` must be Int, but got "
+                              << elem_type->ToString();
+    }
+  }
+  auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(dense_shape_value);
+  auto shp = dense_shape_value->value();
+  auto min_elem = *std::min_element(std::begin(shp), std::end(shp));
+  if (min_elem <= 0) {
+    MS_EXCEPTION(ValueError) << "For COOTensor, the element of `shape` must be a positive integer, but got "
+                             << min_elem << " in it";
+  }
+  ShapeVector dense_shape_vec;
+  (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
+                       [](const ValuePtr &e) -> int64_t {
+                         auto elem = GetValue<int64_t>(e);
+                         return elem;
+                       });
+  if (LongToSize(indices_shp[kIndexOne]) != dense_shape_vec.size()) {
+    MS_EXCEPTION(TypeError) << "For COOTensor, the length of `shape` must be equal to the second "
+                            << "dimension of `indices`: " << indices_shp[kIndexOne] << ", but got "
+                            << dense_shape_vec.size();
+  }
+  for (auto dense_shape_elem : dense_shape_vec) {
+    if (dense_shape_elem <= 0) {
+      MS_EXCEPTION(TypeError) << "For COOTensor, the element of `shape` must be positive, but got "
+                              << dense_shape_value->ToString();
+    }
+  }
+  AbstractBasePtrList element_list{indices, values, dense_shape};
+  return std::make_shared<AbstractTuple>(element_list);
+}
+MIND_API_OPERATOR_IMPL(MakeCOOTensor, BaseOperator);
+REGISTER_PRIMITIVE_EVAL_IMPL(MakeCOOTensor, prim::kPrimMakeCOOTensor, MakeCOOTensorInfer, nullptr, true);
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/core/ops/make_cootensor.h b/mindspore/core/ops/make_cootensor.h
new file mode 100644
index 00000000000..3fb05f71342
--- /dev/null
+++ b/mindspore/core/ops/make_cootensor.h
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_MAKE_COOTENSOR_H_ +#define MINDSPORE_CORE_OPS_MAKE_COOTENSOR_H_ + +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameMakeCOOTensor = "MakeCOOTensor"; +class MIND_API MakeCOOTensor : public BaseOperator { + public: + MIND_API_BASE_MEMBER(MakeCOOTensor); + /// \brief Constructor. + MakeCOOTensor() : BaseOperator(kNameMakeCOOTensor) {} +}; +abstract::AbstractBasePtr MakeCOOTensorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_MAKE_COOTENSOR_H_ diff --git a/mindspore/core/ops/make_csrtensor.cc b/mindspore/core/ops/make_csrtensor.cc new file mode 100644 index 00000000000..e69b47255fa --- /dev/null +++ b/mindspore/core/ops/make_csrtensor.cc @@ -0,0 +1,104 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "ops/make_csrtensor.h" + +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" +#include "ops/op_utils.h" +#include "ops/primitive_c.h" +#include "utils/anf_utils.h" +#include "utils/check_convert_utils.h" +#include "utils/tensor_construct_utils.h" + +namespace mindspore { +namespace ops { +using abstract::AbstractTensor; +using abstract::AbstractTuple; +AbstractBasePtr MakeCSRTensorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list) { + // Inputs: three tensors and a tuple. 
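+  // Expected inputs: indptr is a 1-D integer Tensor of length shape[0] + 1, indices and values are
+  // 1-D Tensors of equal length nnz, and shape is a tuple of positive ints; the checks below enforce
+  // these constraints before returning the (indptr, indices, values, shape) tuple.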
+ const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, kSizeFour); + auto indptr = abstract::CheckArg(op_name, args_spec_list, kIndexZero); + auto indices = abstract::CheckArg(op_name, args_spec_list, kIndexOne); + auto values = abstract::CheckArg(op_name, args_spec_list, kIndexTwo); + auto shape = abstract::CheckArg(op_name, args_spec_list, kIndexThree); + + auto indptr_dtype = indptr->element()->BuildType(); + auto indices_dtype = indices->element()->BuildType(); + CheckSparseIndicesDtype(indptr_dtype, "indptr"); + CheckSparseIndicesDtype(indices_dtype, "indices"); + + auto indptr_shp = indptr->shape()->shape(); + CheckSparseShape(indptr_shp.size(), kSizeOne, "Indptr"); + + auto indices_shp = indices->shape()->shape(); + CheckSparseShape(indices_shp.size(), kSizeOne, "Indices"); + + auto values_shp = values->shape()->shape(); + if (indices_shp[kIndexZero] != values_shp[kIndexZero]) { + MS_EXCEPTION(ValueError) << "Indices and values must have same size, but got: values length: " + << values_shp[kIndexZero] << ", indices length " << indices_shp[kIndexZero]; + } + + auto shape_value = shape->BuildValue()->cast(); + MS_EXCEPTION_IF_NULL(shape_value); + auto shp = shape_value->value(); + ShapeVector shape_vec; + (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(shape_vec), [](const ValuePtr &e) -> int64_t { + auto elem = GetValue(e); + return elem; + }); + if (values_shp.size() + 1 != shape_vec.size()) { + MS_EXCEPTION(ValueError) << "Values' dimension should equal to csr_tensor's dimension - 1, but got" + << "Values' dimension: " << values_shp.size() + << ", csr_tensor's dimension: " << shape_vec.size() << "."; + } + if (shape_vec[kIndexZero] + 1 != indptr_shp[kIndexZero]) { + MS_EXCEPTION(ValueError) << "Indptr must have length (1 + shape[0]), but got: " << indptr_shp[kIndexZero]; + } + size_t shape_size = 1; + auto shape_types = shape->ElementsType(); + for (size_t i = 0; i < shape_vec.size(); ++i) { + if (shape_vec[i] <= 0) { + MS_EXCEPTION(ValueError) << "The element of shape must be positive, but got " << shape_value->ToString(); + } + if ((i > 1) && (shape_vec[i] != values_shp[i - 1])) { + MS_EXCEPTION(ValueError) << "csr_tensor's shape should match with values' shape."; + } + if (!shape_types[i]->isa()) { + MS_EXCEPTION(TypeError) << "The element type of shape must be Int, but got " << shape_types[i]->ToString(); + } + shape_size *= LongToSize(shape_vec[i]); + } + if (static_cast(shape_size) < values_shp[kIndexZero]) { + MS_EXCEPTION(ValueError) << "Shape total size: " << shape_size << " is too small to hold " << values_shp[kIndexZero] + << " non-zero values."; + } + std::vector element_list{indptr, indices, values, shape}; + return std::make_shared(element_list); +} +MIND_API_OPERATOR_IMPL(MakeCSRTensor, BaseOperator); +REGISTER_PRIMITIVE_EVAL_IMPL(MakeCSRTensor, prim::kPrimMakeCSRTensor, MakeCSRTensorInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/make_csrtensor.h b/mindspore/core/ops/make_csrtensor.h new file mode 100644 index 00000000000..259f863fef3 --- /dev/null +++ b/mindspore/core/ops/make_csrtensor.h @@ -0,0 +1,40 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_MAKE_CSRTENSOR_H_ +#define MINDSPORE_CORE_OPS_MAKE_CSRTENSOR_H_ + +#include + +#include "ops/base_operator.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameMakeCSRTensor = "MakeCSRTensor"; +/// \brief MakeCSRTensor op is used to construct CSRTensor. +class MIND_API MakeCSRTensor : public BaseOperator { + public: + MIND_API_BASE_MEMBER(MakeCSRTensor); + /// \brief Constructor. + MakeCSRTensor() : BaseOperator(kNameMakeCSRTensor) {} +}; +abstract::AbstractBasePtr MakeCSRTensorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &args_spec_list); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_MAKE_CSRTENSOR_H_ diff --git a/mindspore/core/ops/op_utils.cc b/mindspore/core/ops/op_utils.cc index 2611c2d41d2..6e8fa01bdf2 100644 --- a/mindspore/core/ops/op_utils.cc +++ b/mindspore/core/ops/op_utils.cc @@ -466,5 +466,99 @@ ValuePtr InferMakeShapeTensorValue(const PrimitivePtr &prim, const AbstractBaseP ValuePtr InferComputeShapeTensorValue(const PrimitivePtr &prim, const AbstractBasePtrList &args) { return EvalShapeTensorValue(prim, args, true); } + +void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp) { + constexpr auto kCSRMulBatchPos = 2; + int dlen = SizeToInt(sparse_shp.size()) - SizeToInt(dense_shp.size()); + if (dlen < 0) { + MS_EXCEPTION(ValueError) << "Currently, only support dense tensor broadcast to sparse tensor, " + << "but sparse tensor has " << sparse_shp.size() << " dimensions, " + << "and dense tensor has " << dense_shp.size() << " dimensions. 
"; + } + for (int i = 0; i < dlen; i++) { + (void)dense_shp.insert(dense_shp.begin(), 1); + } + if (sparse_shp.size() != dense_shp.size()) { + MS_LOG(EXCEPTION) << "Failure: sparse_shp.size() != dense_shp.size()."; + } + if (sparse_shp.size() < 1) { + MS_LOG(EXCEPTION) << "Failure: dense tensor and sparse tensor shapes cannot be zero."; + } + for (size_t i = 0; i < sparse_shp.size(); i++) { + auto s = sparse_shp[i]; + auto d = dense_shp[i]; + if (i < kCSRMulBatchPos) { + if (d != s && d != 1) { + MS_EXCEPTION(ValueError) << "Dense shape cannot broadcast to sparse shape."; + } + } else { + if (d != s) { + MS_EXCEPTION(ValueError) << "Currently, sparse shape and dense shape must equal in feature dimensions."; + } + } + } +} + +void CheckSparseShape(const size_t shape_size, const size_t expected_dim, const std::string &arg_name) { + if (shape_size != expected_dim) { + MS_EXCEPTION(ValueError) << arg_name << " must be a " << expected_dim << "-dimensional tensor, but got a " + << shape_size << "-dimensional tensor."; + } +} + +void CheckSparseIndicesDtype(const TypePtr data_type, const std::string &arg_name) { + if (!(data_type->equal(kInt16) || data_type->equal(kInt32) || data_type->equal(kInt64))) { + MS_EXCEPTION(TypeError) << "The dtype of " << arg_name << " must be Int16 or Int32 or Int64, but got " + << data_type->ToString() << "."; + } +} + +void CheckSparseIndicesDtypeInt32(const TypePtr data_type, const std::string &arg_name) { + if (!data_type->equal(kInt32)) { + MS_EXCEPTION(TypeError) << "The dtype of " << arg_name << " only support Int32 for now, but got " + << data_type->ToString() << "."; + } +} + +ShapeVector ConvertToShapeVector(const abstract::AbstractTuplePtr &shape) { + auto shape_value = shape->BuildValue()->cast(); + MS_EXCEPTION_IF_NULL(shape_value); + ShapeVector shape_vec; + (void)std::transform(std::begin(shape_value->value()), std::end(shape_value->value()), std::back_inserter(shape_vec), + [](const ValuePtr &e) -> int64_t { + auto elem = GetValue(e); + return elem; + }); + return shape_vec; +} + +template +std::shared_ptr InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { + MS_EXCEPTION_IF_NULL(primitive); + constexpr size_t kSizeExpect = 1; + if (args_spec_list.size() != kSizeExpect) { + MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', the number of input should be " << kSizeExpect + << ", but got " << args_spec_list.size() << "."; + } + constexpr size_t kIndex = 0; + auto abs = args_spec_list[kIndex]; + MS_EXCEPTION_IF_NULL(abs); + // To avoid AbstractSparseTensors being generalized to AbstractTuple. 
+ if (dyn_cast(abs) == nullptr) { + auto abs_tuple = dyn_cast(abs); + if (abs_tuple != nullptr) { + return std::make_shared(abs_tuple->elements()); + } + } else if (dyn_cast(abs) != nullptr) { + return dyn_cast(abs); + } + MS_EXCEPTION(TypeError) << "For \'" << primitive->name() << "\', input[" << kIndex + << "] should be AbstractSparseTensor or AbstractTuple, but got " + << abs->BuildType()->ToString() << "."; +} +template std::shared_ptr InferSparseAttr(const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +template std::shared_ptr InferSparseAttr(const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/op_utils.h b/mindspore/core/ops/op_utils.h index 3dea0dddac0..a29a940570c 100644 --- a/mindspore/core/ops/op_utils.h +++ b/mindspore/core/ops/op_utils.h @@ -80,5 +80,23 @@ ValuePtr InferMakeShapeTensorValue(const PrimitivePtr &prim, const AbstractBaseP // Infer shape value of compute-shape op that could change the dim value, e.g. Mul, Add, Sub // Do not support op with multiple outputs for now ValuePtr InferComputeShapeTensorValue(const PrimitivePtr &prim, const AbstractBasePtrList &args); + +void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp); + +void CheckSparseShape(const size_t shape_size, const size_t expected_dim, const std::string &arg_name); + +void CheckSparseIndicesDtype(const TypePtr data_type, const std::string &arg_name); + +void CheckSparseIndicesDtypeInt32(const TypePtr data_type, const std::string &arg_name); + +ShapeVector ConvertToShapeVector(const abstract::AbstractTuplePtr &shape); + +template +std::shared_ptr InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); + +constexpr auto kCSRAvgRows = "csr_avg_rows"; +constexpr auto kIsCSR = "is_csr"; +constexpr auto kCSRDenseShape = "dense_shape"; +constexpr auto kCSRAxis = "axis"; } // namespace mindspore::ops #endif // MINDSPORE_CORE_OPS_OP_UTILS_H diff --git a/mindspore/core/ops/sparse_matrix_add.cc b/mindspore/core/ops/sparse_matrix_add.cc index 18bcf39f7df..bad896d0900 100644 --- a/mindspore/core/ops/sparse_matrix_add.cc +++ b/mindspore/core/ops/sparse_matrix_add.cc @@ -46,21 +46,6 @@ constexpr size_t kAlphaIndex = 10; constexpr size_t kBetaIndex = 11; constexpr int64_t kDefaultRank = 2; constexpr int64_t kBatchedRank = 3; - -inline void CheckSparseShape(const size_t sparse_shape_size, const size_t expected_dim, const std::string &arg_name) { - if (sparse_shape_size != expected_dim) { - MS_EXCEPTION(mindspore::ValueError) << arg_name << " must be a " << expected_dim - << "-dimensional tensor, but got a " << sparse_shape_size - << "-dimensional tensor."; - } -} - -inline void CheckSparseIndicesDtype(const mindspore::TypePtr dtype, const std::string &arg_name) { - if (!(dtype->equal(mindspore::kInt16) || dtype->equal(mindspore::kInt32) || dtype->equal(mindspore::kInt64))) { - MS_EXCEPTION(mindspore::TypeError) << "The dtype of " << arg_name << " must be Int16 or Int32 or Int64, but got " - << dtype->ToString() << "."; - } -} } // namespace void SparseMatrixAdd::set_dense_shape(const std::vector &shape) { (void)this->AddAttr(kDenseShape, api::MakeValue(shape));
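The construction rules enforced by MakeCSRTensorInfer above (indptr of length shape[0] + 1, indices and values of equal length nnz, positive shape entries) can also be checked outside the framework. Below is a minimal standalone C++ sketch of those checks using only the standard library; ValidateCsr and the example data are illustrative names, not part of this patch.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Mirrors the MakeCSRTensor checks for a 2-D CSR tensor: indptr has shape[0] + 1 entries,
// and indices/values share the same length (the number of stored elements, nnz).
void ValidateCsr(const std::vector<int32_t> &indptr, const std::vector<int32_t> &indices,
                 const std::vector<float> &values, const std::vector<int64_t> &shape) {
  if (shape.size() != 2 || shape[0] <= 0 || shape[1] <= 0) {
    throw std::invalid_argument("shape must be two positive dimensions");
  }
  if (static_cast<int64_t>(indptr.size()) != shape[0] + 1) {
    throw std::invalid_argument("indptr must have length shape[0] + 1");
  }
  if (indices.size() != values.size()) {
    throw std::invalid_argument("indices and values must have the same length");
  }
}

int main() {
  // A 3 x 4 CSR matrix with 4 stored elements; CSR2COO would infer a row-index output of length nnz.
  std::vector<int32_t> indptr = {0, 1, 3, 4};
  std::vector<int32_t> indices = {0, 1, 3, 2};
  std::vector<float> values = {1.0f, 2.0f, 3.0f, 4.0f};
  std::vector<int64_t> shape = {3, 4};
  ValidateCsr(indptr, indices, values, shape);
  std::cout << "nnz = " << values.size() << std::endl;
  return 0;
}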