From 080ad981d6f2355d51d7b4c62ffc2d96161d86b9 Mon Sep 17 00:00:00 2001 From: huangmengxi Date: Mon, 14 Feb 2022 10:08:13 +0800 Subject: [PATCH] add csr bprop && csr method --- .../common/optimizer/const_input_to_attr.cc | 3 + .../backend/common/pass/sparse_process.cc | 1 - .../pipeline/jit/parse/data_converter.cc | 3 + mindspore/ccsrc/pipeline/jit/pipeline.cc | 4 +- mindspore/ccsrc/pipeline/jit/resource.cc | 12 +- mindspore/ccsrc/utils/convert_utils.cc | 3 +- mindspore/ccsrc/utils/convert_utils.h | 9 +- mindspore/core/abstract/infer_functions.h | 6 + mindspore/core/abstract/prim_others.cc | 99 ++++++++++++- .../core/abstract/primitive_infer_map.cc | 3 + mindspore/core/base/core_ops.h | 6 + .../_extends/parse/standard_method.py | 28 +++- mindspore/python/mindspore/common/api.py | 2 +- mindspore/python/mindspore/common/tensor.py | 27 ++++ mindspore/python/mindspore/nn/cell.py | 21 +-- .../python/mindspore/ops/_grad/grad_sparse.py | 79 ++++++++++ .../ops/_op_impl/akg/gpu/__init__.py | 3 + .../mindspore/ops/_op_impl/akg/gpu/coo2csr.py | 29 ++++ .../mindspore/ops/_op_impl/akg/gpu/csr2coo.py | 29 ++++ .../ops/_op_impl/akg/gpu/csr_gather.py | 33 +++++ .../mindspore/ops/_op_impl/akg/gpu/csr_mv.py | 3 - .../ops/composite/multitype_ops/add_impl.py | 15 ++ .../composite/multitype_ops/ones_like_impl.py | 6 + .../multitype_ops/zeros_like_impl.py | 14 ++ mindspore/python/mindspore/ops/functional.py | 9 ++ .../mindspore/ops/operations/_csr_ops.py | 135 ++++++++++++++++++ tests/st/sparse/test_coo.py | 36 +++++ tests/st/sparse/test_csr.py | 118 ++++++++++++++- 28 files changed, 709 insertions(+), 27 deletions(-) create mode 100644 mindspore/python/mindspore/ops/_op_impl/akg/gpu/coo2csr.py create mode 100644 mindspore/python/mindspore/ops/_op_impl/akg/gpu/csr2coo.py create mode 100644 mindspore/python/mindspore/ops/_op_impl/akg/gpu/csr_gather.py diff --git a/mindspore/ccsrc/backend/common/optimizer/const_input_to_attr.cc b/mindspore/ccsrc/backend/common/optimizer/const_input_to_attr.cc index 90761d7ef8a..bc916c5a59d 100644 --- a/mindspore/ccsrc/backend/common/optimizer/const_input_to_attr.cc +++ b/mindspore/ccsrc/backend/common/optimizer/const_input_to_attr.cc @@ -58,6 +58,9 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { Register(prim::kPrimUnsortedSegmentMin->name(), {2}); Register(prim::kPrimUnsortedSegmentMax->name(), {2}); Register(prim::kPrimCSRReduceSum->name(), {1}); + Register(prim::kPrimCSRGather->name(), {3}); + Register(prim::kPrimCSR2COO->name(), {1}); + Register(prim::kPrimCOO2CSR->name(), {1}); Register(kSparseGatherV2OpName, {2}); Register(kUnsortedSegmentProdOpName, {2}); Register(kSimpleMeanGradOpName, {1}); diff --git a/mindspore/ccsrc/backend/common/pass/sparse_process.cc b/mindspore/ccsrc/backend/common/pass/sparse_process.cc index 483353fd83e..c48500467d1 100644 --- a/mindspore/ccsrc/backend/common/pass/sparse_process.cc +++ b/mindspore/ccsrc/backend/common/pass/sparse_process.cc @@ -191,7 +191,6 @@ const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const An } auto new_node = cnode->func_graph()->NewCNode(new_inputs); new_node->set_abstract(node->abstract()); - AnfAlgo::SetNodeAttr("is_csr", MakeValue(true), new_node); return new_node; } diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc index ec2e00fb964..2faf5fb9fd4 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc @@ -40,6 +40,8 @@ using MetaTensor = 
mindspore::tensor::MetaTensor; using MetaTensorPtr = mindspore::tensor::MetaTensorPtr; using CSRTensor = mindspore::tensor::CSRTensor; using CSRTensorPtr = mindspore::tensor::CSRTensorPtr; +using COOTensor = mindspore::tensor::COOTensor; +using COOTensorPtr = mindspore::tensor::COOTensorPtr; using InstanceCheckFunc = std::function; using InstanceConvertFunc = std::function; @@ -489,6 +491,7 @@ std::vector GetDataConverters() { std::make_shared>(ObjCast), std::make_shared>(ObjCast), std::make_shared>(ObjCast), + std::make_shared>(ObjCast), std::make_shared>(ConvertTuple), std::make_shared>(ConvertList), std::make_shared>(PyCast), diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 3905c6728eb..894f0b24b09 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -98,6 +98,7 @@ namespace pipeline { using Tensor = mindspore::tensor::Tensor; using MetaTensor = mindspore::tensor::MetaTensor; using CSRTensor = mindspore::tensor::CSRTensor; +using COOTensor = mindspore::tensor::COOTensor; using TensorOrderMap = std::map>; using mindspore::abstract::AbstractTensor; using mindspore::abstract::AbstractTensorPtr; @@ -178,7 +179,8 @@ bool CheckArgValid(const py::handle &arg) { return py::isinstance(arg) || py::isinstance(arg) || py::isinstance(arg) || py::isinstance(arg) || - ((py::isinstance(arg) || py::isinstance(arg)) && !py::hasattr(arg, "__parameter__")); + ((py::isinstance(arg) || py::isinstance(arg) || py::isinstance(arg)) && + !py::hasattr(arg, "__parameter__")); } std::string GetCompileExceptionInfo() { diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index 09b59f8529a..2ee8d268e0a 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -216,7 +216,17 @@ BuiltInTypeMap &GetMethodMap() { }}, {kObjectTypeJTagged, {}}, {kObjectTypeSymbolicKeyType, {}}, - {kObjectTypeEnvType, {}}}; + {kObjectTypeEnvType, {}}, + {kObjectTypeCOOTensorType, + { + {"to_csr", std::string("coo_to_csr")}, + {"to_dense", std::string("coo_to_dense")}, + }}, + {kObjectTypeCSRTensorType, + { + {"to_coo", std::string("csr_to_coo")}, + {"to_dense", std::string("csr_to_dense")}, + }}}; return method_map; } diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc index 1b7bcaa2e4f..dcb6f8bee99 100644 --- a/mindspore/ccsrc/utils/convert_utils.cc +++ b/mindspore/ccsrc/utils/convert_utils.cc @@ -320,7 +320,8 @@ size_t CountValueNum(const ValueTuplePtr &value_tuple) { bool IsCustomCSROP(const AnfNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); - const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul, prim::kPrimCSRMV}; + const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul, prim::kPrimCSRMV, + prim::kPrimCSRGather, prim::kPrimCSR2COO, prim::kPrimCOO2CSR}; return IsOneOfPrimitiveCNode(cnode, prims); } } // namespace mindspore diff --git a/mindspore/ccsrc/utils/convert_utils.h b/mindspore/ccsrc/utils/convert_utils.h index 728330d6fe6..7c55e360657 100644 --- a/mindspore/ccsrc/utils/convert_utils.h +++ b/mindspore/ccsrc/utils/convert_utils.h @@ -94,8 +94,13 @@ const mindspore::HashSet make_sparse_set = {{prim::kMakeCSRTensor}, // sparse_op_set records all sparse_compute operators, which takes sparsetensor // and (possibly) dense tensors, used in backend common optimization pass: // sparse_process.cc -const mindspore::HashSet sparse_op_set = { - {prim::kSparseTensorDenseMatmul}, 
{prim::kCSRDenseMul}, {prim::kCSRReduceSum}, {prim::kCSRMV}, {prim::kCSRMul}}; +const mindspore::HashSet sparse_op_set = {{prim::kSparseTensorDenseMatmul}, + {prim::kCSRDenseMul}, + {prim::kCSRReduceSum}, + {prim::kCSRMV}, + {prim::kCSRMul}, + {prim::kCSRGather}, + {prim::kCSR2COO}}; bool IsCustomCSROP(const AnfNodePtr &cnode); } // namespace mindspore diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index dd481bf4551..012b64d7de7 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -165,6 +165,12 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index a6ee0ddebf8..9b8ec67633b 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -36,6 +36,7 @@ namespace abstract { constexpr auto kCSRDenseShape = "dense_shape"; constexpr auto kCSRAxis = "axis"; constexpr auto kCSRAvgRows = "csr_avg_rows"; +constexpr auto kIsCSR = "is_csr"; AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // An object of a subclass of AbstractBase @@ -439,6 +440,9 @@ AbstractBasePtr InferImplCSRMul(const AnalysisEnginePtr &, const PrimitivePtr &p << "but sparse tensor has " << sparse_shape.size() << " dimensions, " << "and dense tensor has " << dense_shape.size() << " dimensions, "; } + if (dense_shape[0] != sparse_shape[0]) { + MS_EXCEPTION(ValueError) << "Currently, only support dense tensor broadcast with last dim!"; + } auto ret = sparse->values()->Broaden(); MS_EXCEPTION_IF_NULL(sparse->indices()->shape()); @@ -446,7 +450,7 @@ AbstractBasePtr InferImplCSRMul(const AnalysisEnginePtr &, const PrimitivePtr &p int csr_avg_rows = SizeToInt(nnz_vec[0] / dense_shape[0]); primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape)); - + primitive->set_attr(kIsCSR, MakeValue(true)); return ret; } @@ -482,7 +486,7 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr int csr_avg_rows = SizeToInt(nnz_vec[0] / dense_shape[0]); primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape)); - + primitive->set_attr(kIsCSR, MakeValue(true)); return ret; } @@ -532,7 +536,98 @@ AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const Primitive int csr_avg_rows = SizeToInt(nnz_vec[0] / sparse_shape[0]); primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); primitive->set_attr(kCSRDenseShape, 
MakeValue(sparse_shape)); + primitive->set_attr(kIsCSR, MakeValue(true)); + return ret; +} +AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: the indptr and indices of a sparse csr tensor, a dense tensor, and the shape of the sparse tensor. + constexpr auto kCSRShapeSize = 2; + constexpr auto kCSRArgsSize = 4; + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, kCSRArgsSize); + auto indptr = CheckArg(op_name, args_spec_list, 0); + auto indices = CheckArg(op_name, args_spec_list, 1); + auto dense = CheckArg(op_name, args_spec_list, 2); + auto sparse_shape = CheckArg(op_name, args_spec_list, 3); + MS_EXCEPTION_IF_NULL(indptr); + MS_EXCEPTION_IF_NULL(indices); + MS_EXCEPTION_IF_NULL(dense); + MS_EXCEPTION_IF_NULL(sparse_shape); + + if (sparse_shape->size() != kCSRShapeSize) { + MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRShapeSize << "-D inputs!" + << "But sparse tensor has " << sparse_shape->size() << " dimensions."; + } + + auto shape_value = sparse_shape->BuildValue()->cast(); + MS_EXCEPTION_IF_NULL(shape_value); + auto nnz_vec = indices->shape()->shape(); + int64_t csr_avg_rows = nnz_vec[0] / GetValue(shape_value->value()[0]); + primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); + primitive->set_attr(kIsCSR, MakeValue(true)); + + MS_EXCEPTION_IF_NULL(indices->shape()); + ShapeVector out_shape = indices->shape()->shape(); + MS_EXCEPTION_IF_NULL(dense->element()); + auto ret = std::make_shared(dense->element()->BuildType(), out_shape); + return ret; +} + +AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: the indptr of a sparse csr tensor, and the number of non-zero elements. + constexpr auto kCSRArgsSize = 2; + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, kCSRArgsSize); + auto indptr = CheckArg(op_name, args_spec_list, 0); + auto nnz = CheckArg(op_name, args_spec_list, 1); + MS_EXCEPTION_IF_NULL(indptr); + MS_EXCEPTION_IF_NULL(nnz); + + MS_EXCEPTION_IF_NULL(nnz->BuildValue()); + ShapeVector out_shape; + if (nnz->BuildValue()->isa() || nnz->BuildValue()->isa()) { + int64_t nnz_value = GetValue(nnz->BuildValue()); + out_shape.push_back(nnz_value); + } else { + MS_EXCEPTION(ValueError) << "Currently, only support Integer nnz."; + } + + MS_EXCEPTION_IF_NULL(indptr->shape()); + auto num_rows = indptr->shape()->shape()[0] - 1; + int csr_avg_rows = GetValue(nnz->BuildValue()) / num_rows; + primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows)); + primitive->set_attr(kIsCSR, MakeValue(true)); + + MS_EXCEPTION_IF_NULL(indptr->element()); + auto ret = std::make_shared(indptr->element()->BuildType(), out_shape); + return ret; +} + +AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: the row indices of a sparse coo tensor, and the size of its first dimension. 
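+  // Assumed kernel contract: row_indices arrives sorted in ascending order
+  // (COOTensor.to_csr sorts by row before calling this op). The output indptr
+  // then has height + 1 entries, indptr[i] counting the non-zeros in rows [0, i).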
+ constexpr auto kCSRArgsSize = 2; + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, kCSRArgsSize); + auto row_indices = CheckArg(op_name, args_spec_list, 0); + auto height = CheckArg(op_name, args_spec_list, 1); + MS_EXCEPTION_IF_NULL(row_indices); + MS_EXCEPTION_IF_NULL(height); + + MS_EXCEPTION_IF_NULL(height->BuildValue()); + ShapeVector out_shape; + if (height->BuildValue()->isa() || height->BuildValue()->isa()) { + int64_t height_value = GetValue(height->BuildValue()); + out_shape.push_back(height_value + 1); + } else { + MS_EXCEPTION(ValueError) << "Currently, only support Integer height."; + } + + MS_EXCEPTION_IF_NULL(row_indices->element()); + auto ret = std::make_shared(row_indices->element()->BuildType(), out_shape); return ret; } diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 80c8dc813c3..5ac6282821b 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -236,6 +236,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimCSRMul, R{InferImplCSRMul, nullptr, true}}, {prim::kPrimCSRMV, R{InferImplCSRMV, nullptr, true}}, {prim::kPrimCSRReduceSum, R{InferImplCSRReduceSum, nullptr, true}}, + {prim::kPrimCSRGather, R{InferImplCSRGather, nullptr, true}}, + {prim::kPrimCSR2COO, R{InferImplCSR2COO, nullptr, true}}, + {prim::kPrimCOO2CSR, R{InferImplCOO2CSR, nullptr, true}}, // Comm Ops {prim::kPrimAllSwap, R{InferImplAllSwap, nullptr, true}}, {prim::kPrimMemCpyAsync, R{InferImplMemCpyAsync, nullptr, true}}, diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 1d11e894ecf..728a9071d7a 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -164,6 +164,9 @@ constexpr auto kCSRDenseMul = "CSRDenseMul"; constexpr auto kCSRReduceSum = "CSRReduceSum"; constexpr auto kCSRMV = "CSRMV"; constexpr auto kCSRMul = "CSRMul"; +constexpr auto kCSRGather = "CSRGather"; +constexpr auto kCSR2COO = "CSR2COO"; +constexpr auto kCOO2CSR = "COO2CSR"; // Meta Function Graph constexpr auto kJ = "J"; @@ -606,6 +609,9 @@ GVAR_DEF(PrimitivePtr, kPrimCSRDenseMul, std::make_shared(kCSRDenseMu GVAR_DEF(PrimitivePtr, kPrimCSRReduceSum, std::make_shared(kCSRReduceSum)); GVAR_DEF(PrimitivePtr, kPrimCSRMV, std::make_shared(kCSRMV)); GVAR_DEF(PrimitivePtr, kPrimCSRMul, std::make_shared(kCSRMul)); +GVAR_DEF(PrimitivePtr, kPrimCSRGather, std::make_shared(kCSRGather)); +GVAR_DEF(PrimitivePtr, kPrimCSR2COO, std::make_shared(kCSR2COO)); +GVAR_DEF(PrimitivePtr, kPrimCOO2CSR, std::make_shared(kCOO2CSR)); // TensorList GVAR_DEF(PrimitivePtr, kPrimTensorListFromTensor, std::make_shared("TensorListFromTensor")); diff --git a/mindspore/python/mindspore/_extends/parse/standard_method.py b/mindspore/python/mindspore/_extends/parse/standard_method.py index ed0f4d7cf9e..5a66dcec345 100644 --- a/mindspore/python/mindspore/_extends/parse/standard_method.py +++ b/mindspore/python/mindspore/_extends/parse/standard_method.py @@ -18,7 +18,7 @@ from dataclasses import dataclass -from mindspore import Tensor, Parameter +from mindspore import Tensor, Parameter, CSRTensor, COOTensor from mindspore import dtype as mstype from ..._checkparam import Validator as validator @@ -1573,6 +1573,32 @@ def while_cond(x): return x +def coo_to_csr(x): + row_indices = x.indices[:, 0] + col_indices = x.indices[:, 1] + idx_dtype = x.indices.dtype + row_indices, sort_idx = F.sort(row_indices.astype(mstype.float32)) + row_indices 
= row_indices.astype(idx_dtype) + col_indices = col_indices[sort_idx] + values = x.values[sort_idx] + indptr = F.coo2csr(row_indices, x.shape[0]) + return CSRTensor(indptr, col_indices, values, x.shape) + + +def coo_to_dense(x): + zeros_tensor = F.zeros(x.shape, x.values.dtype) + return F.tensor_scatter_update(zeros_tensor, x.indices, x.values) + +def csr_to_coo(x): + row_indices = F.csr2coo(x.indptr, x.values.shape[0]) + coo_indices = P.Stack(1)((row_indices, x.indices)) + return COOTensor(coo_indices, x.values, x.shape) + +def csr_to_dense(x): + coo_tensor = x.to_coo() + return coo_tensor.to_dense() + + @constexpr def empty_tensor(dtype): return Tensor([], dtype) diff --git a/mindspore/python/mindspore/common/api.py b/mindspore/python/mindspore/common/api.py index 3a37b835a93..cc92e12620e 100644 --- a/mindspore/python/mindspore/common/api.py +++ b/mindspore/python/mindspore/common/api.py @@ -290,7 +290,7 @@ class _MindsporeFunctionExecutor: return None new_inputs = [] for i in args_list: - if isinstance(i, (Tensor, CSRTensor)): + if isinstance(i, (Tensor, CSRTensor, COOTensor)): new_inputs.append(i) elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)): new_inputs.append(i) diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py index 3a8851af87d..4e4d69648d0 100644 --- a/mindspore/python/mindspore/common/tensor.py +++ b/mindspore/python/mindspore/common/tensor.py @@ -2468,6 +2468,24 @@ class COOTensor(COOTensor_): def shape(self): return self._shape + def to_csr(self): + "Converts COOTensor to CSRTensor." + row_indices = self.indices[:, 0] + col_indices = self.indices[:, 1] + idx_dtype = self.indices.dtype + row_indices, sort_idx = tensor_operator_registry.get("sort")( + row_indices.astype(mstype.float32)) + row_indices = row_indices.astype(idx_dtype) + col_indices = col_indices[sort_idx] + values = self.values[sort_idx] + indptr = tensor_operator_registry.get("coo2csr")(row_indices, self.shape[0]) + return CSRTensor(indptr, col_indices, values, self.shape) + + def to_dense(self): + zeros_tensor = tensor_operator_registry.get("zeros")(self.shape, self.values.dtype) + return tensor_operator_registry.get("tensor_scatter_update")( + zeros_tensor, self.indices, self.values) + class CSRTensor(CSRTensor_): """ @@ -2566,6 +2584,15 @@ class CSRTensor(CSRTensor_): def to_tuple(self): return self.indptr, self.indices, self.values, self.shape + def to_coo(self): + row_indices = tensor_operator_registry.get("csr2coo")(self.indptr, self.values.shape[0]) + coo_indices = tensor_operator_registry.get("stack")(1)((row_indices, self.indices)) + return COOTensor(coo_indices, self.values, self.shape) + + def to_dense(self): + coo_tensor = self.to_coo() + return coo_tensor.to_dense() + def _vm_compare(*args): """Implement `vm_compare` for tensor.""" diff --git a/mindspore/python/mindspore/nn/cell.py b/mindspore/python/mindspore/nn/cell.py index d1217f9574b..52a5663fb55 100755 --- a/mindspore/python/mindspore/nn/cell.py +++ b/mindspore/python/mindspore/nn/cell.py @@ -32,7 +32,7 @@ from .._checkparam import Validator from ..common import dtype as mstype from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache from ..common.parameter import Parameter, ParameterTuple -from ..common.tensor import Tensor, CSRTensor +from ..common.tensor import Tensor, CSRTensor, COOTensor from ..ops.operations import Cast from ..ops.primitive import Primitive from ..ops.operations import _inner_ops as inner @@ 
-815,6 +815,8 @@ class Cell(Cell_): if i.has_init: i.init_data() new_inputs.append(i) + elif isinstance(i, COOTensor): + new_inputs.append(i) elif isinstance(i, CSRTensor): new_inputs.append(i) elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)): @@ -1267,16 +1269,16 @@ class Cell(Cell_): def _add_mixed_precision_flag(self, **flags): """Add mixed precision flag to current cell""" - if "fp16" in flags and flags["fp16"]: + if "fp16" in flags and flags.get("fp16", False): Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16) - if "fp32" in flags and flags["fp32"]: + if "fp32" in flags and flags.get("fp32", False): Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32) def _add_mixed_precision_flag_recursive(self, **flags): """Add mixed precision flag to each cell""" - if "fp16" in flags and flags["fp16"]: + if "fp16" in flags and flags.get("fp16", False): self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16) - if "fp32" in flags and flags["fp32"]: + if "fp32" in flags and flags.get("fp32", False): self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32) def add_flags(self, **flags): @@ -1876,15 +1878,16 @@ class Cell(Cell_): """ self._recompute() if 'mp_comm_recompute' in kwargs.keys(): - self._mp_comm_recompute(kwargs['mp_comm_recompute']) + self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False)) if 'parallel_optimizer_comm_recompute' in kwargs.keys(): - if kwargs['parallel_optimizer_comm_recompute'] and context.get_auto_parallel_context("pipeline_stages") > 1: + if (kwargs.get('parallel_optimizer_comm_recompute', False) and + context.get_auto_parallel_context("pipeline_stages") > 1): logger.warning("Currently, the communication operator allgathers introduced by optimizer shard " "are not support recomputation in pipeline parallel.") elif context.get_auto_parallel_context("pipeline_stages") == 1: - self._parallel_optimizer_comm_recompute(kwargs['parallel_optimizer_comm_recompute']) + self._parallel_optimizer_comm_recompute(kwargs.get('parallel_optimizer_comm_recompute', False)) if 'recompute_slice_activation' in kwargs.keys(): - self._recompute_slice_activation(kwargs['recompute_slice_activation']) + self._recompute_slice_activation(kwargs.get('recompute_slice_activation', False)) for key, _ in kwargs.items(): if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute', 'recompute_slice_activation'): diff --git a/mindspore/python/mindspore/ops/_grad/grad_sparse.py b/mindspore/python/mindspore/ops/_grad/grad_sparse.py index 48c36f252fd..9855b7dc0e2 100644 --- a/mindspore/python/mindspore/ops/_grad/grad_sparse.py +++ b/mindspore/python/mindspore/ops/_grad/grad_sparse.py @@ -14,8 +14,10 @@ # ============================================================================ """bprop primitives""" +from ...common import dtype as mstype from .. import functional as F from .. import operations as P +from ..operations import _csr_ops from ..composite.multitype_ops.zeros_like_impl import zeros_like from .grad_base import bprops, bprop_getters @@ -78,3 +80,80 @@ def get_bprop_sparse_tensor_dense_matmul(self): values_grad = F.reduce_sum(parts_a * parts_b, 1) return zeros_like(indices), values_grad, zeros_like(dense_shape), dense_grad return bprop + +@bprop_getters.register(_csr_ops.CSRReduceSum) +def get_bprop_csr_reduce_sum(self): + "Back-propagation for CSRReduceSum." 
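+    # The gradient of a sum is the incoming gradient broadcast over the reduced
+    # axis: dout is tiled back to the dense shape, then csr_gather reads that
+    # dense gradient at each non-zero position to form the values gradient.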
+ def bprop(csr_tensor, axis, out, dout): + indptr = csr_tensor.indptr + indices = csr_tensor.indices + shape = csr_tensor.shape + + output_shape_kept_dims = F.reduced_shape(shape, axis) + tile_scaling = F.tuple_div(shape, output_shape_kept_dims) + values_grad_dense = F.tile(F.reshape(dout, output_shape_kept_dims), tile_scaling) + values_grad = F.csr_gather(indptr, indices, values_grad_dense, shape) + return F.make_csr_tensor(indptr, indices, values_grad, shape), zeros_like(axis) + return bprop + +@bprop_getters.register(_csr_ops.CSRMV) +def get_bprop_csr_mv(self): + "Back-propagation for CSRMV." + def bprop(csr_tensor, dense, out, dout): + indptr = F.csr_tensor_get_indptr(csr_tensor) + indices = F.csr_tensor_get_indices(csr_tensor) + values = F.csr_tensor_get_values(csr_tensor) + dense_shape = csr_tensor.shape + + rows = F.csr2coo(indptr, indices.shape[0]) + idx_dtype = rows.dtype + rows_transposed, cols_indexing = F.sort(indices.astype(mstype.float32)) + rows_transposed = rows_transposed.astype(idx_dtype) + cols_transposed = rows[cols_indexing] + values_transposed = values[cols_indexing] + indptr_transposed = F.coo2csr(rows_transposed, dense_shape[1]) + csr_tensor_transposed = F.make_csr_tensor( + indptr_transposed, cols_transposed, values_transposed, (dense_shape[1], dense_shape[0])) + + dense_grad = F.csr_mv(csr_tensor_transposed, dout) + parts_a = F.gather(dout, rows, 0) + parts_b = F.gather(dense, indices, 0) + values_grad = F.reduce_sum(parts_a * parts_b, 1) + return F.make_csr_tensor(indptr, indices, values_grad, csr_tensor.shape), dense_grad + return bprop + +@bprop_getters.register(_csr_ops.CSRMul) +def get_bprop_csr_mul(self): + "Back-propagation for CSRMul." + def bprop(csr_tensor, dense, out, dout): + indptr = csr_tensor.indptr + indices = csr_tensor.indices + values = csr_tensor.values + shape = csr_tensor.shape + + csr_tensor_grad_value = F.csr_mul(F.make_csr_tensor(indptr, indices, dout, shape), dense) + csr_tensor_grad = F.make_csr_tensor(indptr, indices, csr_tensor_grad_value, shape) + dense_grad_value = F.mul(dout, values) + dense_grad = F.make_csr_tensor(indptr, indices, dense_grad_value, shape) + if len(dense.shape) == 1 or dense.shape[0] == 1: + dense_grad = F.csr_reduce_sum(dense_grad, 0) + elif dense.shape[1] == 1: + dense_grad = F.csr_reduce_sum(dense_grad, 1) + else: + row = F.csr2coo(indptr, indices.shape[0]) + coo_idx = P.Stack(-1)((row, indices)) + dense_grad = F.tensor_scatter_update(zeros_like(dense), coo_idx, dense_grad_value) + return csr_tensor_grad, dense_grad + return bprop + +@bprop_getters.register(_csr_ops.CSR2COO) +def get_bprop_csr2coo(self): + def bprop(indptr, nnz, out, dout): + return zeros_like(dout) + return bprop + +@bprop_getters.register(_csr_ops.COO2CSR) +def get_bprop_coo2csr(self): + def bprop(row_indices, height, out, dout): + return zeros_like(dout) + return bprop diff --git a/mindspore/python/mindspore/ops/_op_impl/akg/gpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/akg/gpu/__init__.py index 9dc37093fa7..625d9ac8d9c 100644 --- a/mindspore/python/mindspore/ops/_op_impl/akg/gpu/__init__.py +++ b/mindspore/python/mindspore/ops/_op_impl/akg/gpu/__init__.py @@ -25,4 +25,7 @@ from .notequal import _notequal_akg from .csr_reduce_sum import _csr_reduce_sum_akg from .csr_mv import _csr_mv_akg from .csr_mul import _csr_mul_akg +from .csr_gather import _csr_gather_akg +from .csr2coo import _csr2coo_akg +from .coo2csr import _coo2csr_akg # Please insert op register in lexicographical order of the filename. 
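[Editor's note, not part of the patch] The CSRMV bprop above transposes a CSR matrix without densifying it: csr2coo expands indptr into per-element row indices, a stable sort of the column indices gives the permutation, and coo2csr compresses the former column indices into the new indptr. A minimal NumPy sketch of the same trick, where csr2coo/coo2csr are stand-ins for the AKG kernels registered in the files below:

import numpy as np

def csr2coo(indptr):
    # Row index of each non-zero: row i owns the entries [indptr[i], indptr[i+1]).
    return np.repeat(np.arange(len(indptr) - 1), np.diff(indptr))

def coo2csr(rows, height):
    # indptr[i] = number of non-zeros in rows [0, i); expects `rows` sorted ascending.
    return np.concatenate(([0], np.cumsum(np.bincount(rows, minlength=height))))

indptr = np.array([0, 1, 4, 6])             # same 3x4 CSR as tests/st/sparse/test_csr.py
indices = np.array([3, 0, 1, 2, 1, 3])
values = np.arange(6, dtype=np.float32)

rows = csr2coo(indptr)                      # [0 1 1 1 2 2]
perm = np.argsort(indices, kind="stable")   # mirrors F.sort on the float-cast indices
indptr_t = coo2csr(indices[perm], height=4) # [0 1 3 4 6]: indptr of the 4x3 transpose
indices_t = rows[perm]                      # column indices of the transpose
values_t = values[perm]                     # values reordered to match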
diff --git a/mindspore/python/mindspore/ops/_op_impl/akg/gpu/coo2csr.py b/mindspore/python/mindspore/ops/_op_impl/akg/gpu/coo2csr.py new file mode 100644 index 00000000000..1ff6a9d7e92 --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/akg/gpu/coo2csr.py @@ -0,0 +1,29 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""COO2CSR op""" +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType + +coo2csr_op_info = AkgGpuRegOp("COO2CSR") \ + .fusion_type("OPAQUE") \ + .input(0, "row_indices") \ + .output(0, "output") \ + .dtype_format(DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .get_op_info() + +@op_info_register(coo2csr_op_info) +def _coo2csr_akg(): + """COO2CSR AutoDiff register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/akg/gpu/csr2coo.py b/mindspore/python/mindspore/ops/_op_impl/akg/gpu/csr2coo.py new file mode 100644 index 00000000000..ff0765e1ff3 --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/akg/gpu/csr2coo.py @@ -0,0 +1,29 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CSR2COO op""" +from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType + +csr2coo_op_info = AkgGpuRegOp("CSR2COO") \ + .fusion_type("OPAQUE") \ + .input(0, "indptr") \ + .output(0, "output") \ + .dtype_format(DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .get_op_info() + +@op_info_register(csr2coo_op_info) +def _csr2coo_akg(): + """CSR2COO AutoDiff register""" + return diff --git a/mindspore/python/mindspore/ops/_op_impl/akg/gpu/csr_gather.py b/mindspore/python/mindspore/ops/_op_impl/akg/gpu/csr_gather.py new file mode 100644 index 00000000000..c581cd38840 --- /dev/null +++ b/mindspore/python/mindspore/ops/_op_impl/akg/gpu/csr_gather.py @@ -0,0 +1,33 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""CSRGather op"""
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
+
+csr_gather_op_info = AkgGpuRegOp("CSRGather") \
+    .fusion_type("OPAQUE") \
+    .input(0, "indptr") \
+    .input(1, "indices") \
+    .input(2, "dense") \
+    .output(0, "output") \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \
+                  DataType.F32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
+                  DataType.F32_Default) \
+    .get_op_info()
+
+@op_info_register(csr_gather_op_info)
+def _csr_gather_akg():
+    """CSRGather AutoDiff register"""
+    return
diff --git a/mindspore/python/mindspore/ops/_op_impl/akg/gpu/csr_mv.py b/mindspore/python/mindspore/ops/_op_impl/akg/gpu/csr_mv.py
index 190609f5a0f..7ecaf8b665c 100644
--- a/mindspore/python/mindspore/ops/_op_impl/akg/gpu/csr_mv.py
+++ b/mindspore/python/mindspore/ops/_op_impl/akg/gpu/csr_mv.py
@@ -22,9 +22,6 @@ csr_mv_op_info = AkgGpuRegOp("CSRMV") \
     .input(2, "values") \
     .input(4, "dense_tensor") \
     .output(0, "output") \
-    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \
-                  DataType.F32_Default, \
-                  DataType.F32_Default) \
     .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
                   DataType.F32_Default, \
                   DataType.F32_Default) \
diff --git a/mindspore/python/mindspore/ops/composite/multitype_ops/add_impl.py b/mindspore/python/mindspore/ops/composite/multitype_ops/add_impl.py
index fb95bd44688..90483a50875 100644
--- a/mindspore/python/mindspore/ops/composite/multitype_ops/add_impl.py
+++ b/mindspore/python/mindspore/ops/composite/multitype_ops/add_impl.py
@@ -346,4 +346,19 @@ def _add_nonetensor_tensor(x, y):
     """
     return x + y
 
+
+@_add_backward.register("CSRTensor", "CSRTensor")
+def _add_csrtensor_csrtensor(x, y):
+    """
+    Adds CSRTensor and CSRTensor.
+
+    Args:
+        x (CSRTensor): the first addend.
+        y (CSRTensor): the second addend, assumed to share the indptr and indices of `x`.
+
+    Returns:
+        CSRTensor.
+    """
+    return F.make_csr_tensor(x.indptr, x.indices, x.values + y.values, x.shape)
+
 hyper_add = base.HyperMap(_add_backward)
diff --git a/mindspore/python/mindspore/ops/composite/multitype_ops/ones_like_impl.py b/mindspore/python/mindspore/ops/composite/multitype_ops/ones_like_impl.py
index bd46bc729d9..9900e22b6d2 100644
--- a/mindspore/python/mindspore/ops/composite/multitype_ops/ones_like_impl.py
+++ b/mindspore/python/mindspore/ops/composite/multitype_ops/ones_like_impl.py
@@ -58,6 +58,12 @@ def _ones_like_coo_tensor(x):
     return F.make_coo_tensor(F.coo_tensor_get_indices(x), values,
                              F.coo_tensor_get_dense_shape(x))
 
 
+@ones_like_leaf.register("CSRTensor")
+def _ones_like_csr_tensor(x):
+    """Returns a CSRTensor with the same shape and dtype as x, with all values set to 1."""
+    return F.make_csr_tensor(x.indptr, x.indices, ones_like(x.values), x.shape)
+
+
 @ones_like_leaf.register("Function")
 def _ones_like_func(x):
     """
diff --git a/mindspore/python/mindspore/ops/composite/multitype_ops/zeros_like_impl.py b/mindspore/python/mindspore/ops/composite/multitype_ops/zeros_like_impl.py
index b033ff27faa..33115a94f40 100644
--- a/mindspore/python/mindspore/ops/composite/multitype_ops/zeros_like_impl.py
+++ b/mindspore/python/mindspore/ops/composite/multitype_ops/zeros_like_impl.py
@@ -58,6 +58,20 @@ def _zeros_like_tensor(x):
     return F.zeros_like(x)
 
 
+@zeros_like_leaf.register("COOTensor")
+def _zeros_like_coo_tensor(x):
+    """Returns a COOTensor with the same shape and dtype as x, with all values set to 0."""
+    values = F.zeros_like(x.values)
+    return F.make_coo_tensor(x.indices, values, x.shape)
+
+
+@zeros_like_leaf.register("CSRTensor")
+def _zeros_like_csr_tensor(x):
+    """Returns a CSRTensor with the same shape and dtype as x, with all values set to 0."""
+    values = F.zeros_like(x.values)
+    return F.make_csr_tensor(x.indptr, x.indices, values, x.shape)
+
+
 @zeros_like_leaf.register("TypeType")
 def _zeros_like_type_type(x):
     """Returns x because x is a type.
 This is usually used in backprop progress."""
diff --git a/mindspore/python/mindspore/ops/functional.py b/mindspore/python/mindspore/ops/functional.py
index 05e61557a68..ec3626a4fb8 100644
--- a/mindspore/python/mindspore/ops/functional.py
+++ b/mindspore/python/mindspore/ops/functional.py
@@ -152,6 +152,9 @@ stack = P.Stack()
 csr_mul = _csr_ops.CSRMul()
 csr_mv = _csr_ops.CSRMV()
 csr_reduce_sum = _csr_ops.CSRReduceSum()
+csr_gather = _csr_ops.CSRGather()
+csr2coo = _csr_ops.CSR2COO()
+coo2csr = _csr_ops.COO2CSR()
 
 _select = P.Select()
 
@@ -576,6 +579,7 @@ not_in_dict = Primitive("not_in_dict")
 mixed_precision_cast = Primitive("mixed_precision_cast")
 broadcast_gradient_args = Primitive('BroadcastGradientArgs')
 array_reduce = Primitive('array_reduce')
+zeros = P.Zeros()
 zeros_like = P.ZerosLike()
 distribute = Primitive('distribute')
 embed = Primitive('embed')
@@ -670,6 +674,11 @@ tensor_operator_registry.register('log', log)
 tensor_operator_registry.register('floor', floor)
 # support sparse tensor operators
 tensor_operator_registry.register('csr_mul', csr_mul)
+tensor_operator_registry.register('csr2coo', csr2coo)
+tensor_operator_registry.register('coo2csr', coo2csr)
 tensor_operator_registry.register('narrow', narrow)
+tensor_operator_registry.register('sort', sort)
+tensor_operator_registry.register('zeros', zeros)
+tensor_operator_registry.register('tensor_scatter_update', tensor_scatter_update)
 __all__ = [name for name in dir() if name[0] != "_"]
 __all__.remove('Primitive')
diff --git a/mindspore/python/mindspore/ops/operations/_csr_ops.py b/mindspore/python/mindspore/ops/operations/_csr_ops.py
index 963ca4130fd..247bcc68386 100644
--- a/mindspore/python/mindspore/ops/operations/_csr_ops.py
+++ b/mindspore/python/mindspore/ops/operations/_csr_ops.py
@@ -20,6 +20,9 @@ class CSRReduceSum(PrimitiveWithInfer):
     """
     Reduces a dimension of a CSRTensor by summing all elements in the dimension.
 
+    .. warning::
+        This is an experimental prototype that is subject to change and/or deletion.
+
     Inputs:
         - **sparse_tensor** (CSRTensor) - A CSRTensor.
         - **axis** (int) - The dimensions to reduce.
@@ -64,6 +67,9 @@ class CSRMV(PrimitiveWithInfer):
     """
     Sparse matrix-vector multiplication.
 
+    .. warning::
+        This is an experimental prototype that is subject to change and/or deletion.
+
     Inputs:
         - **sparse_tensor** (CSRTensor) - A CSRTensor.
         - **dense_tensor** (Tensor) - A dense Tensor.
@@ -109,6 +115,9 @@ class CSRMul(PrimitiveWithInfer):
     """
     Elemwise multiplication on a CSRTensor and a dense tensor.
 
+    .. warning::
+        This is an experimental prototype that is subject to change and/or deletion.
+
     Note:
         The op outputs a 1-D dense tensor whose shape and values are the same as input `CSRTensor.values`.
         If expect a CSRTensor output, please use `*` directly, e.g. `x * y`, `x` or `y` can be CSRTensor.
@@ -151,3 +160,129 @@ class CSRMul(PrimitiveWithInfer):
         """Initialize CSRMul"""
         self.init_prim_io_names(inputs=['indptr', 'indices', 'values', 'dense_shape', 'dense_tensor'],
                                 outputs=['output'])
+
+
+class CSRGather(PrimitiveWithInfer):
+    """
+    Returns the values of a CSRTensor indexed from a dense tensor using indptr and indices.
+
+    .. warning::
+        This is an experimental prototype that is subject to change and/or deletion.
+
+    Inputs:
+        - **indptr** (Tensor) - A 1-D Tensor storing the compressed row pointers of the sparse tensor.
+        - **indices** (Tensor) - A 1-D Tensor storing the column indices of the non-zero values.
+        - **dense** (Tensor) - A dense Tensor to gather values from.
+        - **sparse_shape** (tuple) - A tuple of integers specifying the shape of the sparse tensor.
+
+    Outputs:
+        Tensor, with the same dtype as `dense`; its first dimension has the same length as `indices`, and its
+        remaining dimensions match ``dense.shape[2:]``.
+
+    Supported Platforms:
+        ``GPU``
+
+    Examples:
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor, ops
+        >>> from mindspore import dtype as mstype
+        >>> class Net(nn.Cell):
+        ...     def __init__(self):
+        ...         super(Net, self).__init__()
+        ...         self.op = ops.CSRGather()
+        ...
+        ...     def construct(self, indptr, indices, dense, sparse_shape):
+        ...         return self.op(indptr, indices, dense, sparse_shape)
+        >>> indptr = Tensor([0, 1, 2])
+        >>> indices = Tensor([0, 1])
+        >>> sparse_shape = (2, 4)
+        >>> dense = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
+        >>> out = Net()(indptr, indices, dense, sparse_shape)
+        >>> print(out)
+        [1. 1.]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize CSRGather"""
+        self.init_prim_io_names(inputs=['indptr', 'indices', 'dense', 'dense_shape'],
+                                outputs=['output'])
+
+
+class CSR2COO(PrimitiveWithInfer):
+    """
+    Converts the indptr of a CSRTensor to the row indices of a COOTensor.
+
+    .. warning::
+        This is an experimental prototype that is subject to change and/or deletion.
+
+    Inputs:
+        - **indptr** (Tensor) - A 1-D Tensor storing the compressed row pointers of the sparse tensor.
+        - **nnz** (int) - The number of non-zero elements in the sparse tensor.
+
+    Outputs:
+        Tensor, the dtype is the same as `indptr` and has shape (`nnz`,).
+
+    Supported Platforms:
+        ``GPU``
+
+    Examples:
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor, ops
+        >>> class Net(nn.Cell):
+        ...     def __init__(self):
+        ...         super(Net, self).__init__()
+        ...         self.op = ops.CSR2COO()
+        ...
+        ...     def construct(self, indptr, nnz):
+        ...         return self.op(indptr, nnz)
+        >>> indptr = Tensor([0, 1, 2])
+        >>> out = Net()(indptr, 2)
+        >>> print(out)
+        [0 1]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize CSR2COO"""
+        self.init_prim_io_names(inputs=['indptr', 'nnz'], outputs=['output'])
+
+
+class COO2CSR(PrimitiveWithInfer):
+    """
+    Converts the row indices of a COOTensor to the indptr of a CSRTensor.
+
+    .. warning::
+        This is an experimental prototype that is subject to change and/or deletion.
+
+    Inputs:
+        - **row_indices** (Tensor) - A 1-D Tensor storing the row index of each non-zero value, in ascending order.
+        - **height** (int) - The size of the first dimension of the sparse tensor.
+
+    Outputs:
+        Tensor, the dtype is the same as `row_indices` and has shape (`height` + 1,).
+
+    Supported Platforms:
+        ``GPU``
+
+    Examples:
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor, ops
+        >>> from mindspore import dtype as mstype
+        >>> class Net(nn.Cell):
+        ...     def __init__(self):
+        ...         super(Net, self).__init__()
+        ...         self.op = ops.COO2CSR()
+        ...
+        ...     def construct(self, row_indices, height):
+        ...         
return self.op(row_indices, height) + >>> row_indices = Tensor([0, 1]) + >>> out = Net()(row_indices, 2) + >>> print(out) + [0 1 2] + """ + + @prim_attr_register + def __init__(self): + """Initialize COO2CSR""" + self.init_prim_io_names(inputs=['row_indices', 'height'], outputs=['output']) diff --git a/tests/st/sparse/test_coo.py b/tests/st/sparse/test_coo.py index 4c24549aa8d..ceccd5abbb6 100644 --- a/tests/st/sparse/test_coo.py +++ b/tests/st/sparse/test_coo.py @@ -91,3 +91,39 @@ def test_coo_tensor_in_while(): assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0) assert np.allclose(out.values.asnumpy(), values.asnumpy(), .0, .0) assert out.shape == shape + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_coo_method(): + """ + Feature: Test coo tensor methods. + Description: Test coo_tensor.to_csr(), coo_tensor.to_dense(). + Expectation: Success. + """ + class COOToCSRNet(nn.Cell): + def construct(self, coo_tensor): + return coo_tensor.to_csr() + + class COOToDenseNet(nn.Cell): + def construct(self, coo_tensor): + return coo_tensor.to_dense() + + indices = Tensor([[1, 2], [0, 1]], dtype=mstype.int32) + values = Tensor([2, 1], dtype=mstype.float32) + shape = (3, 4) + coo_tensor = COOTensor(indices, values, shape) + + to_csr_output = COOToCSRNet()(coo_tensor) + to_csr_expect_1 = np.array([0, 1, 2, 2], dtype=np.int32) + to_csr_expect_2 = np.array([1, 2], dtype=np.int32) + to_csr_expect_3 = np.array([1, 2], dtype=np.float32) + assert np.allclose(to_csr_output.indptr.asnumpy(), to_csr_expect_1) + assert np.allclose(to_csr_output.indices.asnumpy(), to_csr_expect_2) + assert np.allclose(to_csr_output.values.asnumpy(), to_csr_expect_3) + + to_dense_output = COOToDenseNet()(coo_tensor) + to_dense_expect = np.array( + [[0., 1., 0., 0.], [0., 0., 2., 0.], [0., 0., 0., 0.]], dtype=np.float32) + assert np.allclose(to_dense_output.asnumpy(), to_dense_expect) diff --git a/tests/st/sparse/test_csr.py b/tests/st/sparse/test_csr.py index 7caa8ed61ca..981e9dac85b 100644 --- a/tests/st/sparse/test_csr.py +++ b/tests/st/sparse/test_csr.py @@ -18,7 +18,7 @@ import os import pytest import numpy as np -from mindspore import Tensor, CSRTensor, ms_function, nn, context +from mindspore import Tensor, CSRTensor, ms_function, nn, context, ops from mindspore.ops.operations import _csr_ops from mindspore.common import dtype as mstype from mindspore.train.serialization import export, load @@ -232,8 +232,8 @@ def test_csr_ops(): csr_reducesum = _csr_ops.CSRReduceSum() csrmv = _csr_ops.CSRMV() - indptr = Tensor([0, 1, 2]) - indices = Tensor([0, 1]) + indptr = Tensor([0, 1, 2], dtype=mstype.int32) + indices = Tensor([0, 1], dtype=mstype.int32) values = Tensor([2, 1], dtype=mstype.float32) dense_shape = (2, 4) @@ -331,8 +331,8 @@ def test_csrops_export_and_import_mindir(): sparse2 = dence_tensor * csr_tensor return dense1, dense2, dense3, sparse1, sparse2 - indptr = Tensor([0, 1, 2]) - indices = Tensor([0, 1]) + indptr = Tensor([0, 1, 2], dtype=mstype.int32) + indices = Tensor([0, 1], dtype=mstype.int32) values = Tensor([2, 1], dtype=mstype.float32) shape = (2, 4) dense_tensor = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32) @@ -428,3 +428,111 @@ def test_dtype_csr_tensor(): out2 = graph_test() assert out1 in [mstype.float32] assert out2 in [mstype.float32] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_csr_bprop(): + """ + Feature: Test back-propagation with CSR-related Ops. 
+    Description: Test back-propagation of CSRReduceSum, CSRMul and CSRMV.
+    Expectation: Success.
+    """
+    class CSRMulNet(nn.Cell):
+        def __init__(self):
+            super(CSRMulNet, self).__init__()
+            self.op = _csr_ops.CSRMul()
+
+        def construct(self, csr_tensor, dense):
+            return self.op(csr_tensor, dense)
+
+    class CSRReduceSumNet(nn.Cell):
+        def __init__(self):
+            super(CSRReduceSumNet, self).__init__()
+            self.op = _csr_ops.CSRReduceSum()
+
+        def construct(self, csr_tensor, axis):
+            return self.op(csr_tensor, axis)
+
+    class CSRMVNet(nn.Cell):
+        def __init__(self):
+            super(CSRMVNet, self).__init__()
+            self.op = _csr_ops.CSRMV()
+
+        def construct(self, csr_tensor, dense):
+            return self.op(csr_tensor, dense)
+
+    class BpropNet(nn.Cell):
+        def __init__(self, net):
+            super(BpropNet, self).__init__()
+            self.net = net
+            self.grad_op = ops.GradOperation(get_all=True)
+
+        def construct(self, *inputs):
+            return self.grad_op(self.net)(*inputs)
+
+    indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
+    indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
+    values = Tensor(np.arange(6), dtype=mstype.float32)
+    dense_shape = (3, 4)
+    csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
+
+    csr_mv_arg = Tensor([[1], [2], [3], [4]], dtype=mstype.float32)
+    csr_mv_output_1, csr_mv_output_2 = BpropNet(CSRMVNet())(csr_tensor, csr_mv_arg)
+    csr_mv_expect_1 = np.array([4, 1, 2, 3, 2, 4], dtype=np.float32)
+    csr_mv_expect_2 = np.array([[1], [6], [3], [5]], dtype=np.float32)
+    assert np.allclose(csr_mv_output_1.values.asnumpy(), csr_mv_expect_1)
+    assert np.allclose(csr_mv_output_2.asnumpy(), csr_mv_expect_2)
+
+    csr_reduce_sum_output = BpropNet(CSRReduceSumNet())(csr_tensor, 1)
+    csr_reduce_sum_expect = np.ones(6, dtype=np.float32)
+    assert np.allclose(csr_reduce_sum_output[0].values.asnumpy(), csr_reduce_sum_expect)
+
+    csr_mul_arg_1 = Tensor([[1], [2], [3]], dtype=mstype.float32)
+    csr_mul_output_1_1, csr_mul_output_1_2 = BpropNet(CSRMulNet())(csr_tensor, csr_mul_arg_1)
+    csr_mul_expect_1_1 = np.array([1, 2, 2, 2, 3, 3], dtype=np.float32)
+    csr_mul_expect_1_2 = np.array([[0], [6], [9]], dtype=np.float32)
+    assert np.allclose(csr_mul_output_1_1.values.asnumpy(), csr_mul_expect_1_1)
+    assert np.allclose(csr_mul_output_1_2.asnumpy(), csr_mul_expect_1_2)
+
+    csr_mul_arg_2 = Tensor(np.arange(12).reshape(3, 4), dtype=mstype.float32)
+    csr_mul_output_2_1, csr_mul_output_2_2 = BpropNet(CSRMulNet())(csr_tensor, csr_mul_arg_2)
+    csr_mul_expect_2_1 = np.array([3, 4, 5, 6, 9, 11], dtype=np.float32)
+    csr_mul_expect_2_2 = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32)
+    assert np.allclose(csr_mul_output_2_1.values.asnumpy(), csr_mul_expect_2_1)
+    assert np.allclose(csr_mul_output_2_2.asnumpy(), csr_mul_expect_2_2)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_csr_method():
+    """
+    Feature: Test csr tensor methods.
+    Description: Test csr_tensor.to_coo(), csr_tensor.to_dense().
+    Expectation: Success.
+ """ + class CSRToCOONet(nn.Cell): + def construct(self, csr_tensor): + return csr_tensor.to_coo() + + class CSRToDenseNet(nn.Cell): + def construct(self, csr_tensor): + return csr_tensor.to_dense() + + indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32) + indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32) + values = Tensor(np.arange(6), dtype=mstype.float32) + dense_shape = (3, 4) + csr_tensor = CSRTensor(indptr, indices, values, dense_shape) + + to_coo_output = CSRToCOONet()(csr_tensor) + to_coo_expect_1 = np.array([[0, 3], [1, 0], [1, 1], [1, 2], [2, 1], [2, 3]], dtype=np.int32) + to_coo_expect_2 = np.arange(6).astype(np.float32) + assert np.allclose(to_coo_output.indices.asnumpy(), to_coo_expect_1) + assert np.allclose(to_coo_output.values.asnumpy(), to_coo_expect_2) + + to_dense_output = CSRToDenseNet()(csr_tensor) + to_dense_expect = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32) + assert np.allclose(to_dense_output.asnumpy(), to_dense_expect)
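
[Editor's note, not part of the patch] A quick end-to-end sketch of the user-facing round trip this patch enables, assuming a GPU build (the AKG kernels above are registered for GPU only):

import numpy as np
from mindspore import Tensor, CSRTensor, context
from mindspore.common import dtype as mstype

context.set_context(device_target="GPU")

indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
values = Tensor(np.arange(6), dtype=mstype.float32)
csr = CSRTensor(indptr, indices, values, (3, 4))

coo = csr.to_coo()      # CSR2COO expands indptr into per-element row indices
dense = csr.to_dense()  # to_coo() followed by tensor_scatter_update on zeros

expect = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32)
assert np.allclose(dense.asnumpy(), expect)
# The reverse direction sorts by row, then COO2CSR compresses the row indices.
assert np.allclose(coo.to_csr().values.asnumpy(), values.asnumpy())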