add csr bprop && csr method

huangmengxi 2022-02-14 10:08:13 +08:00
parent 3dea54b28d
commit 080ad981d6
28 changed files with 709 additions and 27 deletions

View File

@@ -58,6 +58,9 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
   Register(prim::kPrimUnsortedSegmentMin->name(), {2});
   Register(prim::kPrimUnsortedSegmentMax->name(), {2});
   Register(prim::kPrimCSRReduceSum->name(), {1});
+  Register(prim::kPrimCSRGather->name(), {3});
+  Register(prim::kPrimCSR2COO->name(), {1});
+  Register(prim::kPrimCOO2CSR->name(), {1});
   Register(kSparseGatherV2OpName, {2});
   Register(kUnsortedSegmentProdOpName, {2});
   Register(kSimpleMeanGradOpName, {1});

View File

@@ -191,7 +191,6 @@ const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const An
   }
   auto new_node = cnode->func_graph()->NewCNode(new_inputs);
   new_node->set_abstract(node->abstract());
-  AnfAlgo::SetNodeAttr("is_csr", MakeValue(true), new_node);
   return new_node;
 }

View File

@@ -40,6 +40,8 @@ using MetaTensor = mindspore::tensor::MetaTensor;
 using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
 using CSRTensor = mindspore::tensor::CSRTensor;
 using CSRTensorPtr = mindspore::tensor::CSRTensorPtr;
+using COOTensor = mindspore::tensor::COOTensor;
+using COOTensorPtr = mindspore::tensor::COOTensorPtr;
 using InstanceCheckFunc = std::function<bool(const py::object &)>;
 using InstanceConvertFunc = std::function<ValuePtr(const py::object &, bool, const TypePtr &)>;
@@ -489,6 +491,7 @@ std::vector<DataConverterPtr> GetDataConverters() {
     std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
     std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
     std::make_shared<ByTypeDataConverter<CSRTensor>>(ObjCast<CSRTensorPtr>),
+    std::make_shared<ByTypeDataConverter<COOTensor>>(ObjCast<COOTensorPtr>),
     std::make_shared<ByTypeDataConverter<py::tuple>>(ConvertTuple),
     std::make_shared<ByTypeDataConverter<py::list>>(ConvertList),
     std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>),

View File

@@ -98,6 +98,7 @@ namespace pipeline {
 using Tensor = mindspore::tensor::Tensor;
 using MetaTensor = mindspore::tensor::MetaTensor;
 using CSRTensor = mindspore::tensor::CSRTensor;
+using COOTensor = mindspore::tensor::COOTensor;
 using TensorOrderMap = std::map<std::string, std::shared_ptr<Tensor>>;
 using mindspore::abstract::AbstractTensor;
 using mindspore::abstract::AbstractTensorPtr;
@@ -178,7 +179,8 @@ bool CheckArgValid(const py::handle &arg) {
   return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<py::none>(arg) ||
          py::isinstance<Number>(arg) ||
-         ((py::isinstance<Tensor>(arg) || py::isinstance<CSRTensor>(arg)) && !py::hasattr(arg, "__parameter__"));
+         ((py::isinstance<Tensor>(arg) || py::isinstance<CSRTensor>(arg) || py::isinstance<COOTensor>(arg)) &&
+          !py::hasattr(arg, "__parameter__"));
 }
 std::string GetCompileExceptionInfo() {

View File

@@ -216,7 +216,17 @@ BuiltInTypeMap &GetMethodMap() {
    }},
   {kObjectTypeJTagged, {}},
   {kObjectTypeSymbolicKeyType, {}},
-  {kObjectTypeEnvType, {}}};
+  {kObjectTypeEnvType, {}},
+  {kObjectTypeCOOTensorType,
+   {
+     {"to_csr", std::string("coo_to_csr")},
+     {"to_dense", std::string("coo_to_dense")},
+   }},
+  {kObjectTypeCSRTensorType,
+   {
+     {"to_coo", std::string("csr_to_coo")},
+     {"to_dense", std::string("csr_to_dense")},
+   }}};
   return method_map;
 }
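With these method-map entries in place, `to_csr`/`to_dense` on a COOTensor and `to_coo`/`to_dense` on a CSRTensor resolve to the Python fallback functions registered later in this commit (coo_to_csr, coo_to_dense, csr_to_coo, csr_to_dense). A minimal graph-mode usage sketch, assuming a GPU target since the backing kernels added here are AKG GPU ops:

import mindspore.nn as nn
from mindspore import Tensor, COOTensor, context
from mindspore import dtype as mstype

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class ToCSRNet(nn.Cell):
    def construct(self, coo_tensor):
        # Dispatched through the kObjectTypeCOOTensorType entry to coo_to_csr.
        return coo_tensor.to_csr()

indices = Tensor([[0, 1], [1, 2]], dtype=mstype.int32)
values = Tensor([1., 2.], dtype=mstype.float32)
csr = ToCSRNet()(COOTensor(indices, values, (3, 4)))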

View File

@@ -320,7 +320,8 @@ size_t CountValueNum(const ValueTuplePtr &value_tuple) {
 bool IsCustomCSROP(const AnfNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(cnode);
-  const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul, prim::kPrimCSRMV};
+  const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul, prim::kPrimCSRMV,
+                           prim::kPrimCSRGather, prim::kPrimCSR2COO, prim::kPrimCOO2CSR};
   return IsOneOfPrimitiveCNode(cnode, prims);
 }
 }  // namespace mindspore

View File

@@ -94,8 +94,13 @@ const mindspore::HashSet<std::string> make_sparse_set = {{prim::kMakeCSRTensor},
 // sparse_op_set records all sparse_compute operators, which takes sparsetensor
 // and (possibly) dense tensors, used in backend common optimization pass:
 // sparse_process.cc
-const mindspore::HashSet<std::string> sparse_op_set = {
-  {prim::kSparseTensorDenseMatmul}, {prim::kCSRDenseMul}, {prim::kCSRReduceSum}, {prim::kCSRMV}, {prim::kCSRMul}};
+const mindspore::HashSet<std::string> sparse_op_set = {{prim::kSparseTensorDenseMatmul},
+                                                       {prim::kCSRDenseMul},
+                                                       {prim::kCSRReduceSum},
+                                                       {prim::kCSRMV},
+                                                       {prim::kCSRMul},
+                                                       {prim::kCSRGather},
+                                                       {prim::kCSR2COO}};
 bool IsCustomCSROP(const AnfNodePtr &cnode);
 }  // namespace mindspore

View File

@@ -165,6 +165,12 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr
                                const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@@ -36,6 +36,7 @@ namespace abstract {
 constexpr auto kCSRDenseShape = "dense_shape";
 constexpr auto kCSRAxis = "axis";
 constexpr auto kCSRAvgRows = "csr_avg_rows";
+constexpr auto kIsCSR = "is_csr";
 AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
   // An object of a subclass of AbstractBase
@@ -439,6 +440,9 @@ AbstractBasePtr InferImplCSRMul(const AnalysisEnginePtr &, const PrimitivePtr &p
                              << "but sparse tensor has " << sparse_shape.size() << " dimensions, "
                              << "and dense tensor has " << dense_shape.size() << " dimensions, ";
   }
+  if (dense_shape[0] != sparse_shape[0]) {
+    MS_EXCEPTION(ValueError) << "Currently, only support dense tensor broadcast with last dim!";
+  }
   auto ret = sparse->values()->Broaden();
   MS_EXCEPTION_IF_NULL(sparse->indices()->shape());
@@ -446,7 +450,7 @@ AbstractBasePtr InferImplCSRMul(const AnalysisEnginePtr &, const PrimitivePtr &p
   int csr_avg_rows = SizeToInt(nnz_vec[0] / dense_shape[0]);
   primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
   primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape));
+  primitive->set_attr(kIsCSR, MakeValue(true));
   return ret;
 }
@@ -482,7 +486,7 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr
   int csr_avg_rows = SizeToInt(nnz_vec[0] / dense_shape[0]);
   primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
   primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape));
+  primitive->set_attr(kIsCSR, MakeValue(true));
   return ret;
 }
@@ -532,7 +536,98 @@ AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const Primitive
   int csr_avg_rows = SizeToInt(nnz_vec[0] / sparse_shape[0]);
   primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
   primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape));
+  primitive->set_attr(kIsCSR, MakeValue(true));
+  return ret;
+}
+
+AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list) {
+  // Inputs: the indptr and indices of a sparse csr tensor, a dense tensor, and the shape of the sparse tensor.
+  constexpr auto kCSRShapeSize = 2;
+  constexpr auto kCSRArgsSize = 4;
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);
+  auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
+  auto sparse_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 3);
+  MS_EXCEPTION_IF_NULL(indptr);
+  MS_EXCEPTION_IF_NULL(indices);
+  MS_EXCEPTION_IF_NULL(dense);
+  MS_EXCEPTION_IF_NULL(sparse_shape);
+  if (sparse_shape->size() != kCSRShapeSize) {
+    MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRShapeSize << "-D inputs!"
+                             << "But sparse tensor has " << sparse_shape->size() << " dimensions.";
+  }
+  auto shape_value = sparse_shape->BuildValue()->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(shape_value);
+  auto nnz_vec = indices->shape()->shape();
+  int64_t csr_avg_rows = nnz_vec[0] / GetValue<int64_t>(shape_value->value()[0]);
+  primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
+  primitive->set_attr(kIsCSR, MakeValue(true));
+
+  MS_EXCEPTION_IF_NULL(indices->shape());
+  ShapeVector out_shape = indices->shape()->shape();
+  MS_EXCEPTION_IF_NULL(dense->element());
+  auto ret = std::make_shared<AbstractTensor>(dense->element()->BuildType(), out_shape);
+  return ret;
+}
+
+AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list) {
+  // Inputs: the indptr of a sparse csr tensor, and the number of non-zero elements.
+  constexpr auto kCSRArgsSize = 2;
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);
+  auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto nnz = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(indptr);
+  MS_EXCEPTION_IF_NULL(nnz);
+  MS_EXCEPTION_IF_NULL(nnz->BuildValue());
+  ShapeVector out_shape;
+  if (nnz->BuildValue()->isa<Int32Imm>() || nnz->BuildValue()->isa<Int64Imm>()) {
+    int64_t nnz_value = GetValue<int64_t>(nnz->BuildValue());
+    out_shape.push_back(nnz_value);
+  } else {
+    MS_EXCEPTION(ValueError) << "Currently, only support Integer nnz.";
+  }
+  MS_EXCEPTION_IF_NULL(indptr->shape());
+  auto num_rows = indptr->shape()->shape()[0] - 1;
+  int csr_avg_rows = GetValue<int64_t>(nnz->BuildValue()) / num_rows;
+  primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
+  primitive->set_attr(kIsCSR, MakeValue(true));
+
+  MS_EXCEPTION_IF_NULL(indptr->element());
+  auto ret = std::make_shared<AbstractTensor>(indptr->element()->BuildType(), out_shape);
+  return ret;
+}
+
+AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list) {
+  // Inputs: the row indices of a sparse coo tensor, and the size of its first dimension.
+  constexpr auto kCSRArgsSize = 2;
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);
+  auto row_indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto height = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(row_indices);
+  MS_EXCEPTION_IF_NULL(height);
+  MS_EXCEPTION_IF_NULL(height->BuildValue());
+  ShapeVector out_shape;
+  if (height->BuildValue()->isa<Int32Imm>() || height->BuildValue()->isa<Int64Imm>()) {
+    int64_t height_value = GetValue<int64_t>(height->BuildValue());
+    out_shape.push_back(height_value + 1);
+  } else {
+    MS_EXCEPTION(ValueError) << "Currently, only support Integer height.";
+  }
+  MS_EXCEPTION_IF_NULL(row_indices->element());
+  auto ret = std::make_shared<AbstractTensor>(row_indices->element()->BuildType(), out_shape);
   return ret;
 }
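Shape-wise the two new conversions are simple: CSR2COO emits one row index per non-zero element, giving an output of shape (nnz,), while COO2CSR emits a row pointer with one entry per row plus a trailing sentinel, giving (height + 1,). A plain-Python restatement of the rule the C++ above implements (names are illustrative, not the C++ API):

def infer_csr2coo_shape(nnz):
    # One row index per stored element.
    return (nnz,)

def infer_coo2csr_shape(height):
    # One pointer per row, plus one sentinel holding the total nnz.
    return (height + 1,)

assert infer_csr2coo_shape(2) == (2,)
assert infer_coo2csr_shape(2) == (3,)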

View File

@@ -236,6 +236,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimCSRMul, R{InferImplCSRMul, nullptr, true}},
     {prim::kPrimCSRMV, R{InferImplCSRMV, nullptr, true}},
     {prim::kPrimCSRReduceSum, R{InferImplCSRReduceSum, nullptr, true}},
+    {prim::kPrimCSRGather, R{InferImplCSRGather, nullptr, true}},
+    {prim::kPrimCSR2COO, R{InferImplCSR2COO, nullptr, true}},
+    {prim::kPrimCOO2CSR, R{InferImplCOO2CSR, nullptr, true}},
     // Comm Ops
     {prim::kPrimAllSwap, R{InferImplAllSwap, nullptr, true}},
     {prim::kPrimMemCpyAsync, R{InferImplMemCpyAsync, nullptr, true}},

View File

@@ -164,6 +164,9 @@ constexpr auto kCSRDenseMul = "CSRDenseMul";
 constexpr auto kCSRReduceSum = "CSRReduceSum";
 constexpr auto kCSRMV = "CSRMV";
 constexpr auto kCSRMul = "CSRMul";
+constexpr auto kCSRGather = "CSRGather";
+constexpr auto kCSR2COO = "CSR2COO";
+constexpr auto kCOO2CSR = "COO2CSR";
 // Meta Function Graph
 constexpr auto kJ = "J";
@@ -606,6 +609,9 @@ GVAR_DEF(PrimitivePtr, kPrimCSRDenseMul, std::make_shared<Primitive>(kCSRDenseMu
 GVAR_DEF(PrimitivePtr, kPrimCSRReduceSum, std::make_shared<Primitive>(kCSRReduceSum));
 GVAR_DEF(PrimitivePtr, kPrimCSRMV, std::make_shared<Primitive>(kCSRMV));
 GVAR_DEF(PrimitivePtr, kPrimCSRMul, std::make_shared<Primitive>(kCSRMul));
+GVAR_DEF(PrimitivePtr, kPrimCSRGather, std::make_shared<Primitive>(kCSRGather));
+GVAR_DEF(PrimitivePtr, kPrimCSR2COO, std::make_shared<Primitive>(kCSR2COO));
+GVAR_DEF(PrimitivePtr, kPrimCOO2CSR, std::make_shared<Primitive>(kCOO2CSR));
 // TensorList
 GVAR_DEF(PrimitivePtr, kPrimTensorListFromTensor, std::make_shared<Primitive>("TensorListFromTensor"));

View File

@@ -18,7 +18,7 @@
 from dataclasses import dataclass
-from mindspore import Tensor, Parameter
+from mindspore import Tensor, Parameter, CSRTensor, COOTensor
 from mindspore import dtype as mstype
 from ..._checkparam import Validator as validator
@@ -1573,6 +1573,32 @@ def while_cond(x):
     return x
+
+
+def coo_to_csr(x):
+    row_indices = x.indices[:, 0]
+    col_indices = x.indices[:, 1]
+    idx_dtype = x.indices.dtype
+    row_indices, sort_idx = F.sort(row_indices.astype(mstype.float32))
+    row_indices = row_indices.astype(idx_dtype)
+    col_indices = col_indices[sort_idx]
+    values = x.values[sort_idx]
+    indptr = F.coo2csr(row_indices, x.shape[0])
+    return CSRTensor(indptr, col_indices, values, x.shape)
+
+
+def coo_to_dense(x):
+    zeros_tensor = F.zeros(x.shape, x.values.dtype)
+    return F.tensor_scatter_update(zeros_tensor, x.indices, x.values)
+
+
+def csr_to_coo(x):
+    row_indices = F.csr2coo(x.indptr, x.values.shape[0])
+    coo_indices = P.Stack(1)((row_indices, x.indices))
+    return COOTensor(coo_indices, x.values, x.shape)
+
+
+def csr_to_dense(x):
+    coo_tensor = x.to_coo()
+    return coo_tensor.to_dense()
+
+
 @constexpr
 def empty_tensor(dtype):
     return Tensor([], dtype)
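coo_to_csr above works in two steps: sort the COO entries by row (the row indices are cast to float32 because F.sort operates on float tensors, then cast back), and compress the sorted row indices into an index pointer with F.coo2csr. The same algorithm in NumPy, as an illustrative sketch:

import numpy as np

def coo_to_csr_numpy(indices, values, shape):
    # Sort entries by row so that each row's columns become contiguous.
    rows, cols = indices[:, 0], indices[:, 1]
    order = np.argsort(rows, kind="stable")
    rows, cols, values = rows[order], cols[order], values[order]
    # indptr[i] counts the non-zeros in rows 0..i-1 (what COO2CSR computes).
    indptr = np.zeros(shape[0] + 1, dtype=rows.dtype)
    np.add.at(indptr, rows + 1, 1)
    indptr = np.cumsum(indptr)
    return indptr, cols, values

indptr, cols, vals = coo_to_csr_numpy(
    np.array([[1, 2], [0, 1]]), np.array([2., 1.]), (3, 4))
# indptr: [0 1 2 2], cols: [1 2], vals: [1. 2.]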

View File

@@ -290,7 +290,7 @@ class _MindsporeFunctionExecutor:
             return None
         new_inputs = []
         for i in args_list:
-            if isinstance(i, (Tensor, CSRTensor)):
+            if isinstance(i, (Tensor, CSRTensor, COOTensor)):
                 new_inputs.append(i)
             elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
                 new_inputs.append(i)

View File

@@ -2468,6 +2468,24 @@ class COOTensor(COOTensor_):
     def shape(self):
         return self._shape
+
+    def to_csr(self):
+        "Converts COOTensor to CSRTensor."
+        row_indices = self.indices[:, 0]
+        col_indices = self.indices[:, 1]
+        idx_dtype = self.indices.dtype
+        row_indices, sort_idx = tensor_operator_registry.get("sort")(
+            row_indices.astype(mstype.float32))
+        row_indices = row_indices.astype(idx_dtype)
+        col_indices = col_indices[sort_idx]
+        values = self.values[sort_idx]
+        indptr = tensor_operator_registry.get("coo2csr")(row_indices, self.shape[0])
+        return CSRTensor(indptr, col_indices, values, self.shape)
+
+    def to_dense(self):
+        zeros_tensor = tensor_operator_registry.get("zeros")(self.shape, self.values.dtype)
+        return tensor_operator_registry.get("tensor_scatter_update")(
+            zeros_tensor, self.indices, self.values)
+
 class CSRTensor(CSRTensor_):
     """
@@ -2566,6 +2584,15 @@ class CSRTensor(CSRTensor_):
     def to_tuple(self):
         return self.indptr, self.indices, self.values, self.shape
+
+    def to_coo(self):
+        row_indices = tensor_operator_registry.get("csr2coo")(self.indptr, self.values.shape[0])
+        coo_indices = tensor_operator_registry.get("stack")(1)((row_indices, self.indices))
+        return COOTensor(coo_indices, self.values, self.shape)
+
+    def to_dense(self):
+        coo_tensor = self.to_coo()
+        return coo_tensor.to_dense()
+
 def _vm_compare(*args):
     """Implement `vm_compare` for tensor."""

View File

@@ -32,7 +32,7 @@ from .._checkparam import Validator
 from ..common import dtype as mstype
 from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
 from ..common.parameter import Parameter, ParameterTuple
-from ..common.tensor import Tensor, CSRTensor
+from ..common.tensor import Tensor, CSRTensor, COOTensor
 from ..ops.operations import Cast
 from ..ops.primitive import Primitive
 from ..ops.operations import _inner_ops as inner
@@ -815,6 +815,8 @@ class Cell(Cell_):
                 if i.has_init:
                     i.init_data()
                 new_inputs.append(i)
+            elif isinstance(i, COOTensor):
+                new_inputs.append(i)
             elif isinstance(i, CSRTensor):
                 new_inputs.append(i)
             elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
@@ -1267,16 +1269,16 @@ class Cell(Cell_):
     def _add_mixed_precision_flag(self, **flags):
         """Add mixed precision flag to current cell"""
-        if "fp16" in flags and flags["fp16"]:
+        if "fp16" in flags and flags.get("fp16", False):
             Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16)
-        if "fp32" in flags and flags["fp32"]:
+        if "fp32" in flags and flags.get("fp32", False):
             Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32)
     def _add_mixed_precision_flag_recursive(self, **flags):
         """Add mixed precision flag to each cell"""
-        if "fp16" in flags and flags["fp16"]:
+        if "fp16" in flags and flags.get("fp16", False):
             self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16)
-        if "fp32" in flags and flags["fp32"]:
+        if "fp32" in flags and flags.get("fp32", False):
             self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
     def add_flags(self, **flags):
@@ -1876,15 +1878,16 @@ class Cell(Cell_):
         """
         self._recompute()
         if 'mp_comm_recompute' in kwargs.keys():
-            self._mp_comm_recompute(kwargs['mp_comm_recompute'])
+            self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
         if 'parallel_optimizer_comm_recompute' in kwargs.keys():
-            if kwargs['parallel_optimizer_comm_recompute'] and context.get_auto_parallel_context("pipeline_stages") > 1:
+            if (kwargs.get('parallel_optimizer_comm_recompute', False) and
+                    context.get_auto_parallel_context("pipeline_stages") > 1):
                 logger.warning("Currently, the communication operator allgathers introduced by optimizer shard "
                                "are not support recomputation in pipeline parallel.")
             elif context.get_auto_parallel_context("pipeline_stages") == 1:
-                self._parallel_optimizer_comm_recompute(kwargs['parallel_optimizer_comm_recompute'])
+                self._parallel_optimizer_comm_recompute(kwargs.get('parallel_optimizer_comm_recompute', False))
         if 'recompute_slice_activation' in kwargs.keys():
-            self._recompute_slice_activation(kwargs['recompute_slice_activation'])
+            self._recompute_slice_activation(kwargs.get('recompute_slice_activation', False))
         for key, _ in kwargs.items():
             if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute', 'recompute_slice_activation'):

View File

@@ -14,8 +14,10 @@
 # ============================================================================
 """bprop primitives"""
+from ...common import dtype as mstype
 from .. import functional as F
 from .. import operations as P
+from ..operations import _csr_ops
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from .grad_base import bprops, bprop_getters
@@ -78,3 +80,80 @@ def get_bprop_sparse_tensor_dense_matmul(self):
         values_grad = F.reduce_sum(parts_a * parts_b, 1)
         return zeros_like(indices), values_grad, zeros_like(dense_shape), dense_grad
     return bprop
+
+
+@bprop_getters.register(_csr_ops.CSRReduceSum)
+def get_bprop_csr_reduce_sum(self):
+    "Back-propagation for CSRReduceSum."
+    def bprop(csr_tensor, axis, out, dout):
+        indptr = csr_tensor.indptr
+        indices = csr_tensor.indices
+        shape = csr_tensor.shape
+
+        # Broadcast dout back to the dense shape, then sample it at the
+        # stored (row, col) positions with CSRGather.
+        output_shape_kept_dims = F.reduced_shape(shape, axis)
+        tile_scaling = F.tuple_div(shape, output_shape_kept_dims)
+        values_grad_dense = F.tile(F.reshape(dout, output_shape_kept_dims), tile_scaling)
+        values_grad = F.csr_gather(indptr, indices, values_grad_dense, shape)
+        return F.make_csr_tensor(indptr, indices, values_grad, shape), zeros_like(axis)
+    return bprop
+
+
+@bprop_getters.register(_csr_ops.CSRMV)
+def get_bprop_csr_mv(self):
+    "Back-propagation for CSRMV."
+    def bprop(csr_tensor, dense, out, dout):
+        indptr = F.csr_tensor_get_indptr(csr_tensor)
+        indices = F.csr_tensor_get_indices(csr_tensor)
+        values = F.csr_tensor_get_values(csr_tensor)
+        dense_shape = csr_tensor.shape
+
+        # Build the transposed CSR tensor to compute the dense gradient A^T @ dout.
+        rows = F.csr2coo(indptr, indices.shape[0])
+        idx_dtype = rows.dtype
+        rows_transposed, cols_indexing = F.sort(indices.astype(mstype.float32))
+        rows_transposed = rows_transposed.astype(idx_dtype)
+        cols_transposed = rows[cols_indexing]
+        values_transposed = values[cols_indexing]
+        indptr_transposed = F.coo2csr(rows_transposed, dense_shape[1])
+        csr_tensor_transposed = F.make_csr_tensor(
+            indptr_transposed, cols_transposed, values_transposed, (dense_shape[1], dense_shape[0]))
+
+        dense_grad = F.csr_mv(csr_tensor_transposed, dout)
+        parts_a = F.gather(dout, rows, 0)
+        parts_b = F.gather(dense, indices, 0)
+        values_grad = F.reduce_sum(parts_a * parts_b, 1)
+        return F.make_csr_tensor(indptr, indices, values_grad, csr_tensor.shape), dense_grad
+    return bprop
+
+
+@bprop_getters.register(_csr_ops.CSRMul)
+def get_bprop_csr_mul(self):
+    "Back-propagation for CSRMul."
+    def bprop(csr_tensor, dense, out, dout):
+        indptr = csr_tensor.indptr
+        indices = csr_tensor.indices
+        values = csr_tensor.values
+        shape = csr_tensor.shape
+
+        csr_tensor_grad_value = F.csr_mul(F.make_csr_tensor(indptr, indices, dout, shape), dense)
+        csr_tensor_grad = F.make_csr_tensor(indptr, indices, csr_tensor_grad_value, shape)
+        dense_grad_value = F.mul(dout, values)
+        dense_grad = F.make_csr_tensor(indptr, indices, dense_grad_value, shape)
+        if len(dense.shape) == 1 or dense.shape[0] == 1:
+            dense_grad = F.csr_reduce_sum(dense_grad, 0)
+        elif dense.shape[1] == 1:
+            dense_grad = F.csr_reduce_sum(dense_grad, 1)
+        else:
+            row = F.csr2coo(indptr, indices.shape[0])
+            coo_idx = P.Stack(-1)((row, indices))
+            dense_grad = F.tensor_scatter_update(zeros_like(dense), coo_idx, dense_grad_value)
+        return csr_tensor_grad, dense_grad
+    return bprop
+
+
+@bprop_getters.register(_csr_ops.CSR2COO)
+def get_bprop_csr2coo(self):
+    "Back-propagation for CSR2COO."
+    def bprop(indptr, nnz, out, dout):
+        return zeros_like(dout)
+    return bprop
+
+
+@bprop_getters.register(_csr_ops.COO2CSR)
+def get_bprop_coo2csr(self):
+    "Back-propagation for COO2CSR."
+    def bprop(row_indices, height, out, dout):
+        return zeros_like(dout)
+    return bprop
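The CSRReduceSum bprop above tiles `dout` back to the dense shape and then samples it at the stored positions with the new CSRGather, so every stored value receives the gradient of its row. A NumPy re-derivation of that logic for the 3x4 pattern used in test_csr_bprop later in this commit:

import numpy as np

indptr = np.array([0, 1, 4, 6])
indices = np.array([3, 0, 1, 2, 1, 3])
shape = (3, 4)

# Reducing over axis 1 yields shape (3, 1); the test uses dout = ones.
dout = np.ones((3, 1), dtype=np.float32)
values_grad_dense = np.tile(dout, (1, shape[1]))        # F.tile
rows = np.repeat(np.arange(shape[0]), np.diff(indptr))  # expand indptr to rows
values_grad = values_grad_dense[rows, indices]          # F.csr_gather
# values_grad == [1. 1. 1. 1. 1. 1.], matching csr_reduce_sum_expect.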

View File

@@ -25,4 +25,7 @@ from .notequal import _notequal_akg
 from .csr_reduce_sum import _csr_reduce_sum_akg
 from .csr_mv import _csr_mv_akg
 from .csr_mul import _csr_mul_akg
+from .csr_gather import _csr_gather_akg
+from .csr2coo import _csr2coo_akg
+from .coo2csr import _coo2csr_akg
 # Please insert op register in lexicographical order of the filename.

View File

@@ -0,0 +1,29 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""COO2CSR op"""
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
+
+coo2csr_op_info = AkgGpuRegOp("COO2CSR") \
+    .fusion_type("OPAQUE") \
+    .input(0, "row_indices") \
+    .output(0, "output") \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .get_op_info()
+
+
+@op_info_register(coo2csr_op_info)
+def _coo2csr_akg():
+    """COO2CSR AutoDiff register"""
+    return

View File

@@ -0,0 +1,29 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""CSR2COO op"""
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
+
+csr2coo_op_info = AkgGpuRegOp("CSR2COO") \
+    .fusion_type("OPAQUE") \
+    .input(0, "indptr") \
+    .output(0, "output") \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .get_op_info()
+
+
+@op_info_register(csr2coo_op_info)
+def _csr2coo_akg():
+    """CSR2COO AutoDiff register"""
+    return

View File

@@ -0,0 +1,33 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""CSRGather op"""
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
+
+csr_gather_op_info = AkgGpuRegOp("CSRGather") \
+    .fusion_type("OPAQUE") \
+    .input(0, "indptr") \
+    .input(1, "indices") \
+    .input(2, "dense") \
+    .output(0, "output") \
+    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \
+                  DataType.F32_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
+                  DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(csr_gather_op_info)
+def _csr_gather_akg():
+    """CSRGather AutoDiff register"""
+    return

View File

@@ -22,9 +22,6 @@ csr_mv_op_info = AkgGpuRegOp("CSRMV") \
     .input(2, "values") \
     .input(4, "dense_tensor") \
     .output(0, "output") \
-    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \
-                  DataType.F32_Default, \
-                  DataType.F32_Default) \
     .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
                   DataType.F32_Default, \
                   DataType.F32_Default) \

View File

@@ -346,4 +346,19 @@ def _add_nonetensor_tensor(x, y):
     """
     return x + y
+
+
+@_add_backward.register("CSRTensor", "CSRTensor")
+def _add_csrtensor_csrtensor(x, y):
+    """
+    Adds CSRTensor and CSRTensor.
+
+    Args:
+        x (CSRTensor): x
+        y (CSRTensor): y
+
+    Returns:
+        CSRTensor.
+    """
+    return F.make_csr_tensor(x.indptr, x.indices, x.values + y.values, x.shape)
+
 hyper_add = base.HyperMap(_add_backward)
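Note the registration above adds only the stored values and reuses x's indptr and indices; this is sound under the assumption that both operands share the same sparsity pattern, which holds when the gradients being accumulated come from the same forward CSRTensor. A tiny NumPy illustration of that assumption:

import numpy as np

# Both gradients share (indptr, indices), so summing values is the whole add.
grad_x_values = np.array([1., 2., 3.])
grad_y_values = np.array([0.5, 0.5, 0.5])
summed_values = grad_x_values + grad_y_values  # -> [1.5 2.5 3.5]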

View File

@@ -58,6 +58,12 @@ def _ones_like_coo_tensor(x):
     return F.make_coo_tensor(F.coo_tensor_get_indices(x), values, F.coo_tensor_get_dense_shape(x))
+
+
+@ones_like_leaf.register("CSRTensor")
+def _ones_like_csr_tensor(x):
+    """Returns a CSRTensor with the same sparsity as x and all stored values set to 1."""
+    return F.make_csr_tensor(x.indptr, x.indices, ones_like(x.values), x.shape)
+
 @ones_like_leaf.register("Function")
 def _ones_like_func(x):
     """

View File

@@ -58,6 +58,20 @@ def _zeros_like_tensor(x):
     return F.zeros_like(x)
+
+
+@zeros_like_leaf.register("COOTensor")
+def _zeros_like_coo_tensor(x):
+    """Returns a COOTensor with the same sparsity as x and all stored values set to 0."""
+    values = F.zeros_like(x.values)
+    return F.make_coo_tensor(x.indices, values, x.shape)
+
+
+@zeros_like_leaf.register("CSRTensor")
+def _zeros_like_csr_tensor(x):
+    """Returns a CSRTensor with the same sparsity as x and all stored values set to 0."""
+    values = F.zeros_like(x.values)
+    return F.make_csr_tensor(x.indptr, x.indices, values, x.shape)
+
 @zeros_like_leaf.register("TypeType")
 def _zeros_like_type_type(x):
     """Returns x because x is a type. This is usually used in backprop progress."""

View File

@@ -152,6 +152,9 @@ stack = P.Stack()
 csr_mul = _csr_ops.CSRMul()
 csr_mv = _csr_ops.CSRMV()
 csr_reduce_sum = _csr_ops.CSRReduceSum()
+csr_gather = _csr_ops.CSRGather()
+csr2coo = _csr_ops.CSR2COO()
+coo2csr = _csr_ops.COO2CSR()
 _select = P.Select()
@@ -576,6 +579,7 @@ not_in_dict = Primitive("not_in_dict")
 mixed_precision_cast = Primitive("mixed_precision_cast")
 broadcast_gradient_args = Primitive('BroadcastGradientArgs')
 array_reduce = Primitive('array_reduce')
+zeros = P.Zeros()
 zeros_like = P.ZerosLike()
 distribute = Primitive('distribute')
 embed = Primitive('embed')
@@ -670,6 +674,11 @@ tensor_operator_registry.register('log', log)
 tensor_operator_registry.register('floor', floor)
 # support sparse tensor operators
 tensor_operator_registry.register('csr_mul', csr_mul)
+tensor_operator_registry.register('csr2coo', csr2coo)
+tensor_operator_registry.register('coo2csr', coo2csr)
 tensor_operator_registry.register('narrow', narrow)
+tensor_operator_registry.register('sort', sort)
+tensor_operator_registry.register('zeros', zeros)
+tensor_operator_registry.register('tensor_scatter_update', tensor_scatter_update)
 __all__ = [name for name in dir() if name[0] != "_"]
 __all__.remove('Primitive')

View File

@@ -20,6 +20,9 @@ class CSRReduceSum(PrimitiveWithInfer):
     """
     Reduces a dimension of a CSRTensor by summing all elements in the dimension.
+
+    .. warning::
+        This is an experimental prototype that is subject to change and/or deletion.
+
     Inputs:
         - **sparse_tensor** (CSRTensor) - A CSRTensor.
         - **axis** (int) - The dimensions to reduce.
@@ -64,6 +67,9 @@ class CSRMV(PrimitiveWithInfer):
     """
     Sparse matrix-vector multiplication.
+
+    .. warning::
+        This is an experimental prototype that is subject to change and/or deletion.
+
     Inputs:
         - **sparse_tensor** (CSRTensor) - A CSRTensor.
         - **dense_tensor** (Tensor) - A dense Tensor.
@@ -109,6 +115,9 @@ class CSRMul(PrimitiveWithInfer):
     """
     Elemwise multiplication on a CSRTensor and a dense tensor.
+
+    .. warning::
+        This is an experimental prototype that is subject to change and/or deletion.
+
     Note:
         The op outputs a 1-D dense tensor whose shape and values are the same as input `CSRTensor.values`.
         If expect a CSRTensor output, please use `*` directly, e.g. `x * y`, `x` or `y` can be CSRTensor.
@@ -151,3 +160,129 @@ class CSRMul(PrimitiveWithInfer):
         """Initialize CSRMul"""
         self.init_prim_io_names(inputs=['indptr', 'indices', 'values', 'dense_shape', 'dense_tensor'],
                                 outputs=['output'])
+
+
+class CSRGather(PrimitiveWithInfer):
+    """
+    Gathers the entries located at the positions given by `indptr` and `indices` from a dense tensor,
+    returning the values of the corresponding CSRTensor.
+
+    .. warning::
+        This is an experimental prototype that is subject to change and/or deletion.
+
+    Inputs:
+        - **indptr** (Tensor) - A Tensor.
+        - **indices** (Tensor) - A Tensor.
+        - **dense** (Tensor) - A Tensor.
+        - **sparse_shape** (tuple) - A tuple of integers.
+
+    Outputs:
+        Tensor, the dtype is the same as `dense`; the first dimension has the same shape as `indices`, and
+        the remaining dimensions are the same as ``dense.shape[2:]``.
+
+    Supported Platforms:
+        ``GPU``
+
+    Examples:
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor, ops
+        >>> from mindspore import dtype as mstype
+        >>> class Net(nn.Cell):
+        ...     def __init__(self):
+        ...         super(Net, self).__init__()
+        ...         self.op = ops.CSRGather()
+        ...
+        ...     def construct(self, indptr, indices, dense, sparse_shape):
+        ...         return self.op(indptr, indices, dense, sparse_shape)
+        >>> indptr = Tensor([0, 1, 2])
+        >>> indices = Tensor([0, 1])
+        >>> sparse_shape = (2, 4)
+        >>> dense = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
+        >>> out = Net()(indptr, indices, dense, sparse_shape)
+        >>> print(out)
+        [1. 1.]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize CSRGather"""
+        self.init_prim_io_names(inputs=['indptr', 'indices', 'dense', 'dense_shape'],
+                                outputs=['output'])
+
+
+class CSR2COO(PrimitiveWithInfer):
+    """
+    Converts the indptr of a CSRTensor to the row indices of a COOTensor.
+
+    .. warning::
+        This is an experimental prototype that is subject to change and/or deletion.
+
+    Inputs:
+        - **indptr** (Tensor) - A Tensor.
+        - **nnz** (int) - The number of non-zero elements in the sparse tensor.
+
+    Outputs:
+        Tensor, the dtype is the same as `indptr` and the shape is (`nnz`,).
+
+    Supported Platforms:
+        ``GPU``
+
+    Examples:
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor, ops
+        >>> class Net(nn.Cell):
+        ...     def __init__(self):
+        ...         super(Net, self).__init__()
+        ...         self.op = ops.CSR2COO()
+        ...
+        ...     def construct(self, indptr, nnz):
+        ...         return self.op(indptr, nnz)
+        >>> indptr = Tensor([0, 1, 2])
+        >>> out = Net()(indptr, 2)
+        >>> print(out)
+        [0 1]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize CSR2COO"""
+        self.init_prim_io_names(inputs=['indptr', 'nnz'], outputs=['output'])
+
+
+class COO2CSR(PrimitiveWithInfer):
+    """
+    Converts the row indices of a COOTensor to the indptr of a CSRTensor.
+
+    .. warning::
+        This is an experimental prototype that is subject to change and/or deletion.
+
+    Inputs:
+        - **row_indices** (Tensor) - A Tensor.
+        - **height** (int) - The size of the first dimension of the sparse tensor.
+
+    Outputs:
+        Tensor, the dtype is the same as `row_indices` and the shape is (`height` + 1,).
+
+    Supported Platforms:
+        ``GPU``
+
+    Examples:
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor, ops
+        >>> from mindspore import dtype as mstype
+        >>> class Net(nn.Cell):
+        ...     def __init__(self):
+        ...         super(Net, self).__init__()
+        ...         self.op = ops.COO2CSR()
+        ...
+        ...     def construct(self, row_indices, height):
+        ...         return self.op(row_indices, height)
+        >>> row_indices = Tensor([0, 1])
+        >>> out = Net()(row_indices, 2)
+        >>> print(out)
+        [0 1 2]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize COO2CSR"""
+        self.init_prim_io_names(inputs=['row_indices', 'height'], outputs=['output'])
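On well-formed inputs, CSR2COO and COO2CSR invert each other; a hedged PyNative round-trip sketch (GPU assumed):

from mindspore import Tensor, context
from mindspore import dtype as mstype
from mindspore.ops.operations import _csr_ops

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

csr2coo = _csr_ops.CSR2COO()
coo2csr = _csr_ops.COO2CSR()

indptr = Tensor([0, 1, 2], dtype=mstype.int32)
row_indices = csr2coo(indptr, 2)       # expected: [0 1]
indptr_back = coo2csr(row_indices, 2)  # expected: [0 1 2]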

View File

@@ -91,3 +91,39 @@ def test_coo_tensor_in_while():
     assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0)
     assert np.allclose(out.values.asnumpy(), values.asnumpy(), .0, .0)
     assert out.shape == shape
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_coo_method():
+    """
+    Feature: Test coo tensor methods.
+    Description: Test coo_tensor.to_csr(), coo_tensor.to_dense().
+    Expectation: Success.
+    """
+    class COOToCSRNet(nn.Cell):
+        def construct(self, coo_tensor):
+            return coo_tensor.to_csr()
+
+    class COOToDenseNet(nn.Cell):
+        def construct(self, coo_tensor):
+            return coo_tensor.to_dense()
+
+    indices = Tensor([[1, 2], [0, 1]], dtype=mstype.int32)
+    values = Tensor([2, 1], dtype=mstype.float32)
+    shape = (3, 4)
+    coo_tensor = COOTensor(indices, values, shape)
+
+    to_csr_output = COOToCSRNet()(coo_tensor)
+    to_csr_expect_1 = np.array([0, 1, 2, 2], dtype=np.int32)
+    to_csr_expect_2 = np.array([1, 2], dtype=np.int32)
+    to_csr_expect_3 = np.array([1, 2], dtype=np.float32)
+    assert np.allclose(to_csr_output.indptr.asnumpy(), to_csr_expect_1)
+    assert np.allclose(to_csr_output.indices.asnumpy(), to_csr_expect_2)
+    assert np.allclose(to_csr_output.values.asnumpy(), to_csr_expect_3)
+
+    to_dense_output = COOToDenseNet()(coo_tensor)
+    to_dense_expect = np.array(
+        [[0., 1., 0., 0.], [0., 0., 2., 0.], [0., 0., 0., 0.]], dtype=np.float32)
+    assert np.allclose(to_dense_output.asnumpy(), to_dense_expect)

View File

@@ -18,7 +18,7 @@ import os
 import pytest
 import numpy as np
-from mindspore import Tensor, CSRTensor, ms_function, nn, context
+from mindspore import Tensor, CSRTensor, ms_function, nn, context, ops
 from mindspore.ops.operations import _csr_ops
 from mindspore.common import dtype as mstype
 from mindspore.train.serialization import export, load
@@ -232,8 +232,8 @@ def test_csr_ops():
     csr_reducesum = _csr_ops.CSRReduceSum()
     csrmv = _csr_ops.CSRMV()
-    indptr = Tensor([0, 1, 2])
-    indices = Tensor([0, 1])
+    indptr = Tensor([0, 1, 2], dtype=mstype.int32)
+    indices = Tensor([0, 1], dtype=mstype.int32)
     values = Tensor([2, 1], dtype=mstype.float32)
     dense_shape = (2, 4)
@@ -331,8 +331,8 @@ def test_csrops_export_and_import_mindir():
         sparse2 = dence_tensor * csr_tensor
         return dense1, dense2, dense3, sparse1, sparse2
-    indptr = Tensor([0, 1, 2])
-    indices = Tensor([0, 1])
+    indptr = Tensor([0, 1, 2], dtype=mstype.int32)
+    indices = Tensor([0, 1], dtype=mstype.int32)
     values = Tensor([2, 1], dtype=mstype.float32)
     shape = (2, 4)
     dense_tensor = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
@@ -428,3 +428,111 @@ def test_dtype_csr_tensor():
     out2 = graph_test()
     assert out1 in [mstype.float32]
     assert out2 in [mstype.float32]
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_csr_bprop():
+    """
+    Feature: Test back-propagation with CSR-related Ops.
+    Description: Test back-propagation of CSRReduceSum, CSRMul and CSRMV.
+    Expectation: Success.
+    """
+    class CSRMulNet(nn.Cell):
+        def __init__(self):
+            super(CSRMulNet, self).__init__()
+            self.op = _csr_ops.CSRMul()
+
+        def construct(self, csr_tensor, dense):
+            return self.op(csr_tensor, dense)
+
+    class CSRReduceSumNet(nn.Cell):
+        def __init__(self):
+            super(CSRReduceSumNet, self).__init__()
+            self.op = _csr_ops.CSRReduceSum()
+
+        def construct(self, csr_tensor, axis):
+            return self.op(csr_tensor, axis)
+
+    class CSRMVNet(nn.Cell):
+        def __init__(self):
+            super(CSRMVNet, self).__init__()
+            self.op = _csr_ops.CSRMV()
+
+        def construct(self, csr_tensor, dense):
+            return self.op(csr_tensor, dense)
+
+    class BpropNet(nn.Cell):
+        def __init__(self, net):
+            super(BpropNet, self).__init__()
+            self.net = net
+            self.grad_op = ops.GradOperation(get_all=True)
+
+        def construct(self, *inputs):
+            return self.grad_op(self.net)(*inputs)
+
+    indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
+    indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
+    values = Tensor(np.arange(6), dtype=mstype.float32)
+    dense_shape = (3, 4)
+    csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
+
+    csr_mv_arg = Tensor([[1], [2], [3], [4]], dtype=mstype.float32)
+    csr_mv_output_1, csr_mv_output_2 = BpropNet(CSRMVNet())(csr_tensor, csr_mv_arg)
+    csr_mv_expect_1 = np.array([4, 1, 2, 3, 2, 4], dtype=np.float32)
+    csr_mv_expect_2 = np.array([[1], [6], [3], [5]], dtype=np.float32)
+    assert np.allclose(csr_mv_output_1.values.asnumpy(), csr_mv_expect_1)
+    assert np.allclose(csr_mv_output_2.asnumpy(), csr_mv_expect_2)
+
+    csr_reduce_sum_output = BpropNet(CSRReduceSumNet())(csr_tensor, 1)
+    csr_reduce_sum_expect = np.ones(6, dtype=np.float32)
+    assert np.allclose(csr_reduce_sum_output[0].values.asnumpy(), csr_reduce_sum_expect)
+
+    csr_mul_arg_1 = Tensor([[1], [2], [3]], dtype=mstype.float32)
+    csr_mul_output_1_1, csr_mul_output_1_2 = BpropNet(CSRMulNet())(csr_tensor, csr_mul_arg_1)
+    csr_mul_expect_1_1 = np.array([1, 2, 2, 2, 3, 3], dtype=np.float32)
+    csr_mul_expect_1_2 = np.array([[0], [6], [9]], dtype=np.float32)
+    assert np.allclose(csr_mul_output_1_1.values.asnumpy(), csr_mul_expect_1_1)
+    assert np.allclose(csr_mul_output_1_2.asnumpy(), csr_mul_expect_1_2)
+
+    csr_mul_arg_2 = Tensor(np.arange(12).reshape(3, 4), dtype=mstype.float32)
+    csr_mul_output_2_1, csr_mul_output_2_2 = BpropNet(CSRMulNet())(csr_tensor, csr_mul_arg_2)
+    csr_mul_expect_2_1 = np.array([3, 4, 5, 6, 9, 11], dtype=np.float32)
+    csr_mul_expect_2_2 = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32)
+    assert np.allclose(csr_mul_output_2_1.values.asnumpy(), csr_mul_expect_2_1)
+    assert np.allclose(csr_mul_output_2_2.asnumpy(), csr_mul_expect_2_2)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_csr_method():
+    """
+    Feature: Test csr tensor methods.
+    Description: Test csr_tensor.to_coo(), csr_tensor.to_dense().
+    Expectation: Success.
+    """
+    class CSRToCOONet(nn.Cell):
+        def construct(self, csr_tensor):
+            return csr_tensor.to_coo()
+
+    class CSRToDenseNet(nn.Cell):
+        def construct(self, csr_tensor):
+            return csr_tensor.to_dense()
+
+    indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
+    indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
+    values = Tensor(np.arange(6), dtype=mstype.float32)
+    dense_shape = (3, 4)
+    csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
+
+    to_coo_output = CSRToCOONet()(csr_tensor)
+    to_coo_expect_1 = np.array([[0, 3], [1, 0], [1, 1], [1, 2], [2, 1], [2, 3]], dtype=np.int32)
+    to_coo_expect_2 = np.arange(6).astype(np.float32)
+    assert np.allclose(to_coo_output.indices.asnumpy(), to_coo_expect_1)
+    assert np.allclose(to_coo_output.values.asnumpy(), to_coo_expect_2)
+
+    to_dense_output = CSRToDenseNet()(csr_tensor)
+    to_dense_expect = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32)
+    assert np.allclose(to_dense_output.asnumpy(), to_dense_expect)
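For reference, the CSRMV expectations in test_csr_bprop can be reproduced by hand: with `dout` defaulting to ones of shape (3, 1), the values gradient is dout[row] * dense[col] for each stored element, and the dense gradient is the column-wise sum A^T @ dout. A NumPy check of the numbers asserted above:

import numpy as np

indptr = np.array([0, 1, 4, 6])
indices = np.array([3, 0, 1, 2, 1, 3])
values = np.arange(6, dtype=np.float32)
dense = np.array([[1], [2], [3], [4]], dtype=np.float32)

rows = np.repeat(np.arange(3), np.diff(indptr))
dout = np.ones((3, 1), dtype=np.float32)

values_grad = dout[rows, 0] * dense[indices, 0]  # -> [4. 1. 2. 3. 2. 4.]
a_dense = np.zeros((3, 4), dtype=np.float32)
a_dense[rows, indices] = values
dense_grad = a_dense.T @ dout                    # -> [[1.] [6.] [3.] [5.]]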