add csr bprop && csr method

This commit is contained in:
huangmengxi 2022-02-14 10:08:13 +08:00
parent 3dea54b28d
commit 080ad981d6
28 changed files with 709 additions and 27 deletions

View File

@ -58,6 +58,9 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
Register(prim::kPrimUnsortedSegmentMin->name(), {2});
Register(prim::kPrimUnsortedSegmentMax->name(), {2});
Register(prim::kPrimCSRReduceSum->name(), {1});
Register(prim::kPrimCSRGather->name(), {3});
Register(prim::kPrimCSR2COO->name(), {1});
Register(prim::kPrimCOO2CSR->name(), {1});
Register(kSparseGatherV2OpName, {2});
Register(kUnsortedSegmentProdOpName, {2});
Register(kSimpleMeanGradOpName, {1});

View File

@ -191,7 +191,6 @@ const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const An
}
auto new_node = cnode->func_graph()->NewCNode(new_inputs);
new_node->set_abstract(node->abstract());
AnfAlgo::SetNodeAttr("is_csr", MakeValue(true), new_node);
return new_node;
}

View File

@ -40,6 +40,8 @@ using MetaTensor = mindspore::tensor::MetaTensor;
using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
using CSRTensor = mindspore::tensor::CSRTensor;
using CSRTensorPtr = mindspore::tensor::CSRTensorPtr;
using COOTensor = mindspore::tensor::COOTensor;
using COOTensorPtr = mindspore::tensor::COOTensorPtr;
using InstanceCheckFunc = std::function<bool(const py::object &)>;
using InstanceConvertFunc = std::function<ValuePtr(const py::object &, bool, const TypePtr &)>;
@ -489,6 +491,7 @@ std::vector<DataConverterPtr> GetDataConverters() {
std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
std::make_shared<ByTypeDataConverter<CSRTensor>>(ObjCast<CSRTensorPtr>),
std::make_shared<ByTypeDataConverter<COOTensor>>(ObjCast<COOTensorPtr>),
std::make_shared<ByTypeDataConverter<py::tuple>>(ConvertTuple),
std::make_shared<ByTypeDataConverter<py::list>>(ConvertList),
std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>),

View File

@ -98,6 +98,7 @@ namespace pipeline {
using Tensor = mindspore::tensor::Tensor;
using MetaTensor = mindspore::tensor::MetaTensor;
using CSRTensor = mindspore::tensor::CSRTensor;
using COOTensor = mindspore::tensor::COOTensor;
using TensorOrderMap = std::map<std::string, std::shared_ptr<Tensor>>;
using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTensorPtr;
@ -178,7 +179,8 @@ bool CheckArgValid(const py::handle &arg) {
return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<py::none>(arg) ||
py::isinstance<Number>(arg) ||
((py::isinstance<Tensor>(arg) || py::isinstance<CSRTensor>(arg)) && !py::hasattr(arg, "__parameter__"));
((py::isinstance<Tensor>(arg) || py::isinstance<CSRTensor>(arg) || py::isinstance<COOTensor>(arg)) &&
!py::hasattr(arg, "__parameter__"));
}
std::string GetCompileExceptionInfo() {

View File

@ -216,7 +216,17 @@ BuiltInTypeMap &GetMethodMap() {
}},
{kObjectTypeJTagged, {}},
{kObjectTypeSymbolicKeyType, {}},
{kObjectTypeEnvType, {}}};
{kObjectTypeEnvType, {}},
{kObjectTypeCOOTensorType,
{
{"to_csr", std::string("coo_to_csr")},
{"to_dense", std::string("coo_to_dense")},
}},
{kObjectTypeCSRTensorType,
{
{"to_coo", std::string("csr_to_coo")},
{"to_dense", std::string("csr_to_dense")},
}}};
return method_map;
}
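
With these method-map entries, the conversion methods resolve through the standard methods above inside graph mode as well. A minimal usage sketch (assuming a GPU backend, since the underlying CSR kernels registered in this commit are GPU-only):

import mindspore.nn as nn
from mindspore import Tensor, COOTensor
from mindspore import dtype as mstype

class ToCSRNet(nn.Cell):
    def construct(self, coo_tensor):
        # dispatched via kObjectTypeCOOTensorType -> "coo_to_csr"
        return coo_tensor.to_csr()

indices = Tensor([[1, 2], [0, 1]], dtype=mstype.int32)
values = Tensor([2, 1], dtype=mstype.float32)
csr = ToCSRNet()(COOTensor(indices, values, (3, 4)))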

View File

@ -320,7 +320,8 @@ size_t CountValueNum(const ValueTuplePtr &value_tuple) {
bool IsCustomCSROP(const AnfNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul, prim::kPrimCSRMV};
const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul, prim::kPrimCSRMV,
prim::kPrimCSRGather, prim::kPrimCSR2COO, prim::kPrimCOO2CSR};
return IsOneOfPrimitiveCNode(cnode, prims);
}
} // namespace mindspore

View File

@ -94,8 +94,13 @@ const mindspore::HashSet<std::string> make_sparse_set = {{prim::kMakeCSRTensor},
// sparse_op_set records all sparse compute operators, which take a sparse
// tensor and (possibly) dense tensors, used in backend common optimization
// pass: sparse_process.cc
const mindspore::HashSet<std::string> sparse_op_set = {
{prim::kSparseTensorDenseMatmul}, {prim::kCSRDenseMul}, {prim::kCSRReduceSum}, {prim::kCSRMV}, {prim::kCSRMul}};
const mindspore::HashSet<std::string> sparse_op_set = {{prim::kSparseTensorDenseMatmul},
{prim::kCSRDenseMul},
{prim::kCSRReduceSum},
{prim::kCSRMV},
{prim::kCSRMul},
{prim::kCSRGather},
{prim::kCSR2COO}};
bool IsCustomCSROP(const AnfNodePtr &cnode);
} // namespace mindspore

View File

@ -165,6 +165,12 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -36,6 +36,7 @@ namespace abstract {
constexpr auto kCSRDenseShape = "dense_shape";
constexpr auto kCSRAxis = "axis";
constexpr auto kCSRAvgRows = "csr_avg_rows";
constexpr auto kIsCSR = "is_csr";
AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// An object of a subclass of AbstractBase
@ -439,6 +440,9 @@ AbstractBasePtr InferImplCSRMul(const AnalysisEnginePtr &, const PrimitivePtr &p
<< "but sparse tensor has " << sparse_shape.size() << " dimensions, "
<< "and dense tensor has " << dense_shape.size() << " dimensions, ";
}
if (dense_shape[0] != sparse_shape[0]) {
MS_EXCEPTION(ValueError) << "Currently, only support dense tensor broadcast with last dim!";
}
auto ret = sparse->values()->Broaden();
MS_EXCEPTION_IF_NULL(sparse->indices()->shape());
@ -446,7 +450,7 @@ AbstractBasePtr InferImplCSRMul(const AnalysisEnginePtr &, const PrimitivePtr &p
int csr_avg_rows = SizeToInt(nnz_vec[0] / dense_shape[0]);
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape));
primitive->set_attr(kIsCSR, MakeValue(true));
return ret;
}
@ -482,7 +486,7 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr
int csr_avg_rows = SizeToInt(nnz_vec[0] / dense_shape[0]);
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape));
primitive->set_attr(kIsCSR, MakeValue(true));
return ret;
}
@ -532,7 +536,98 @@ AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const Primitive
int csr_avg_rows = SizeToInt(nnz_vec[0] / sparse_shape[0]);
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape));
primitive->set_attr(kIsCSR, MakeValue(true));
return ret;
}
AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: the indptr and indices of a sparse csr tensor, a dense tensor, and the shape of the sparse tensor.
constexpr auto kCSRShapeSize = 2;
constexpr auto kCSRArgsSize = 4;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
auto sparse_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 3);
MS_EXCEPTION_IF_NULL(indptr);
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(dense);
MS_EXCEPTION_IF_NULL(sparse_shape);
if (sparse_shape->size() != kCSRShapeSize) {
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRShapeSize << "-D inputs!"
<< "But sparse tensor has " << sparse_shape->size() << " dimensions.";
}
auto shape_value = sparse_shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(shape_value);
auto nnz_vec = indices->shape()->shape();
int64_t csr_avg_rows = nnz_vec[0] / GetValue<int64_t>(shape_value->value()[0]);
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kIsCSR, MakeValue(true));
MS_EXCEPTION_IF_NULL(indices->shape());
ShapeVector out_shape = indices->shape()->shape();
MS_EXCEPTION_IF_NULL(dense->element());
auto ret = std::make_shared<AbstractTensor>(dense->element()->BuildType(), out_shape);
return ret;
}
AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: the indptr of a sparse csr tensor, and the number of non-zero elements.
constexpr auto kCSRArgsSize = 2;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto nnz = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(indptr);
MS_EXCEPTION_IF_NULL(nnz);
MS_EXCEPTION_IF_NULL(nnz->BuildValue());
ShapeVector out_shape;
if (nnz->BuildValue()->isa<Int32Imm>() || nnz->BuildValue()->isa<Int64Imm>()) {
int64_t nnz_value = GetValue<int64_t>(nnz->BuildValue());
out_shape.push_back(nnz_value);
} else {
MS_EXCEPTION(ValueError) << "Currently, only support Integer nnz.";
}
MS_EXCEPTION_IF_NULL(indptr->shape());
auto num_rows = indptr->shape()->shape()[0] - 1;
int csr_avg_rows = GetValue<int64_t>(nnz->BuildValue()) / num_rows;
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kIsCSR, MakeValue(true));
MS_EXCEPTION_IF_NULL(indptr->element());
auto ret = std::make_shared<AbstractTensor>(indptr->element()->BuildType(), out_shape);
return ret;
}
AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: the row indices of a sparse coo tensor, and the size of its first dimension.
constexpr auto kCSRArgsSize = 2;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);
auto row_indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto height = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(row_indices);
MS_EXCEPTION_IF_NULL(height);
MS_EXCEPTION_IF_NULL(height->BuildValue());
ShapeVector out_shape;
if (height->BuildValue()->isa<Int32Imm>() || height->BuildValue()->isa<Int64Imm>()) {
int64_t height_value = GetValue<int64_t>(height->BuildValue());
out_shape.push_back(height_value + 1);
} else {
MS_EXCEPTION(ValueError) << "Currently, only support Integer height.";
}
MS_EXCEPTION_IF_NULL(row_indices->element());
auto ret = std::make_shared<AbstractTensor>(row_indices->element()->BuildType(), out_shape);
return ret;
}
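
For reference, the two conversions these infer functions describe have simple NumPy analogues (a sketch, not part of the commit): CSR2COO expands indptr into one row index per stored value, giving shape (nnz,), while COO2CSR accumulates per-row counts, giving shape (height + 1,).

import numpy as np

indptr = np.array([0, 1, 2])
rows = np.repeat(np.arange(len(indptr) - 1), np.diff(indptr))  # CSR2COO
print(rows)  # [0 1], shape (nnz,)

row_indices = np.array([0, 1])
indptr_back = np.concatenate(([0], np.cumsum(np.bincount(row_indices, minlength=2))))  # COO2CSR
print(indptr_back)  # [0 1 2], shape (height + 1,)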

View File

@ -236,6 +236,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimCSRMul, R{InferImplCSRMul, nullptr, true}},
{prim::kPrimCSRMV, R{InferImplCSRMV, nullptr, true}},
{prim::kPrimCSRReduceSum, R{InferImplCSRReduceSum, nullptr, true}},
{prim::kPrimCSRGather, R{InferImplCSRGather, nullptr, true}},
{prim::kPrimCSR2COO, R{InferImplCSR2COO, nullptr, true}},
{prim::kPrimCOO2CSR, R{InferImplCOO2CSR, nullptr, true}},
// Comm Ops
{prim::kPrimAllSwap, R{InferImplAllSwap, nullptr, true}},
{prim::kPrimMemCpyAsync, R{InferImplMemCpyAsync, nullptr, true}},

View File

@ -164,6 +164,9 @@ constexpr auto kCSRDenseMul = "CSRDenseMul";
constexpr auto kCSRReduceSum = "CSRReduceSum";
constexpr auto kCSRMV = "CSRMV";
constexpr auto kCSRMul = "CSRMul";
constexpr auto kCSRGather = "CSRGather";
constexpr auto kCSR2COO = "CSR2COO";
constexpr auto kCOO2CSR = "COO2CSR";
// Meta Function Graph
constexpr auto kJ = "J";
@ -606,6 +609,9 @@ GVAR_DEF(PrimitivePtr, kPrimCSRDenseMul, std::make_shared<Primitive>(kCSRDenseMu
GVAR_DEF(PrimitivePtr, kPrimCSRReduceSum, std::make_shared<Primitive>(kCSRReduceSum));
GVAR_DEF(PrimitivePtr, kPrimCSRMV, std::make_shared<Primitive>(kCSRMV));
GVAR_DEF(PrimitivePtr, kPrimCSRMul, std::make_shared<Primitive>(kCSRMul));
GVAR_DEF(PrimitivePtr, kPrimCSRGather, std::make_shared<Primitive>(kCSRGather));
GVAR_DEF(PrimitivePtr, kPrimCSR2COO, std::make_shared<Primitive>(kCSR2COO));
GVAR_DEF(PrimitivePtr, kPrimCOO2CSR, std::make_shared<Primitive>(kCOO2CSR));
// TensorList
GVAR_DEF(PrimitivePtr, kPrimTensorListFromTensor, std::make_shared<Primitive>("TensorListFromTensor"));

View File

@ -18,7 +18,7 @@
from dataclasses import dataclass
from mindspore import Tensor, Parameter
from mindspore import Tensor, Parameter, CSRTensor, COOTensor
from mindspore import dtype as mstype
from ..._checkparam import Validator as validator
@ -1573,6 +1573,32 @@ def while_cond(x):
return x
def coo_to_csr(x):
row_indices = x.indices[:, 0]
col_indices = x.indices[:, 1]
idx_dtype = x.indices.dtype
row_indices, sort_idx = F.sort(row_indices.astype(mstype.float32))
row_indices = row_indices.astype(idx_dtype)
col_indices = col_indices[sort_idx]
values = x.values[sort_idx]
indptr = F.coo2csr(row_indices, x.shape[0])
return CSRTensor(indptr, col_indices, values, x.shape)
def coo_to_dense(x):
zeros_tensor = F.zeros(x.shape, x.values.dtype)
return F.tensor_scatter_update(zeros_tensor, x.indices, x.values)
def csr_to_coo(x):
row_indices = F.csr2coo(x.indptr, x.values.shape[0])
coo_indices = P.Stack(1)((row_indices, x.indices))
return COOTensor(coo_indices, x.values, x.shape)
def csr_to_dense(x):
coo_tensor = x.to_coo()
return coo_tensor.to_dense()
@constexpr
def empty_tensor(dtype):
return Tensor([], dtype)
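
The sort-based conversion in coo_to_csr above can be checked against a small NumPy model (illustration only; the float32 round-trip in coo_to_csr exists because F.sort operates on float input):

import numpy as np

indices = np.array([[1, 2], [0, 1]])              # COO (row, col) pairs
values = np.array([2.0, 1.0], dtype=np.float32)
shape = (3, 4)

order = np.argsort(indices[:, 0], kind="stable")  # sort entries by row
rows, cols = indices[order, 0], indices[order, 1]
vals = values[order]
# coo2csr: indptr[i + 1] = number of entries in rows 0..i
indptr = np.concatenate(([0], np.cumsum(np.bincount(rows, minlength=shape[0]))))
print(indptr, cols, vals)  # [0 1 2 2] [1 2] [1. 2.]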

View File

@ -290,7 +290,7 @@ class _MindsporeFunctionExecutor:
return None
new_inputs = []
for i in args_list:
if isinstance(i, (Tensor, CSRTensor)):
if isinstance(i, (Tensor, CSRTensor, COOTensor)):
new_inputs.append(i)
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
new_inputs.append(i)

View File

@ -2468,6 +2468,24 @@ class COOTensor(COOTensor_):
def shape(self):
return self._shape
def to_csr(self):
"Converts COOTensor to CSRTensor."
row_indices = self.indices[:, 0]
col_indices = self.indices[:, 1]
idx_dtype = self.indices.dtype
row_indices, sort_idx = tensor_operator_registry.get("sort")(
row_indices.astype(mstype.float32))
row_indices = row_indices.astype(idx_dtype)
col_indices = col_indices[sort_idx]
values = self.values[sort_idx]
indptr = tensor_operator_registry.get("coo2csr")(row_indices, self.shape[0])
return CSRTensor(indptr, col_indices, values, self.shape)
def to_dense(self):
zeros_tensor = tensor_operator_registry.get("zeros")(self.shape, self.values.dtype)
return tensor_operator_registry.get("tensor_scatter_update")(
zeros_tensor, self.indices, self.values)
class CSRTensor(CSRTensor_):
"""
@ -2566,6 +2584,15 @@ class CSRTensor(CSRTensor_):
def to_tuple(self):
return self.indptr, self.indices, self.values, self.shape
def to_coo(self):
row_indices = tensor_operator_registry.get("csr2coo")(self.indptr, self.values.shape[0])
coo_indices = tensor_operator_registry.get("stack")(1)((row_indices, self.indices))
return COOTensor(coo_indices, self.values, self.shape)
def to_dense(self):
coo_tensor = self.to_coo()
return coo_tensor.to_dense()
def _vm_compare(*args):
"""Implement `vm_compare` for tensor."""

View File

@ -32,7 +32,7 @@ from .._checkparam import Validator
from ..common import dtype as mstype
from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
from ..common.parameter import Parameter, ParameterTuple
from ..common.tensor import Tensor, CSRTensor
from ..common.tensor import Tensor, CSRTensor, COOTensor
from ..ops.operations import Cast
from ..ops.primitive import Primitive
from ..ops.operations import _inner_ops as inner
@ -815,6 +815,8 @@ class Cell(Cell_):
if i.has_init:
i.init_data()
new_inputs.append(i)
elif isinstance(i, COOTensor):
new_inputs.append(i)
elif isinstance(i, CSRTensor):
new_inputs.append(i)
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
@ -1267,16 +1269,16 @@ class Cell(Cell_):
def _add_mixed_precision_flag(self, **flags):
"""Add mixed precision flag to current cell"""
if "fp16" in flags and flags["fp16"]:
if "fp16" in flags and flags.get("fp16", False):
Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16)
if "fp32" in flags and flags["fp32"]:
if "fp32" in flags and flags.get("fp32", False):
Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32)
def _add_mixed_precision_flag_recursive(self, **flags):
"""Add mixed precision flag to each cell"""
if "fp16" in flags and flags["fp16"]:
if "fp16" in flags and flags.get("fp16", False):
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16)
if "fp32" in flags and flags["fp32"]:
if "fp32" in flags and flags.get("fp32", False):
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
def add_flags(self, **flags):
@ -1876,15 +1878,16 @@ class Cell(Cell_):
"""
self._recompute()
if 'mp_comm_recompute' in kwargs.keys():
self._mp_comm_recompute(kwargs['mp_comm_recompute'])
self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
if 'parallel_optimizer_comm_recompute' in kwargs.keys():
if kwargs['parallel_optimizer_comm_recompute'] and context.get_auto_parallel_context("pipeline_stages") > 1:
if (kwargs.get('parallel_optimizer_comm_recompute', False) and
context.get_auto_parallel_context("pipeline_stages") > 1):
logger.warning("Currently, the communication operator allgathers introduced by optimizer shard "
"are not support recomputation in pipeline parallel.")
elif context.get_auto_parallel_context("pipeline_stages") == 1:
self._parallel_optimizer_comm_recompute(kwargs['parallel_optimizer_comm_recompute'])
self._parallel_optimizer_comm_recompute(kwargs.get('parallel_optimizer_comm_recompute', False))
if 'recompute_slice_activation' in kwargs.keys():
self._recompute_slice_activation(kwargs['recompute_slice_activation'])
self._recompute_slice_activation(kwargs.get('recompute_slice_activation', False))
for key, _ in kwargs.items():
if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute', 'recompute_slice_activation'):

View File

@ -14,8 +14,10 @@
# ============================================================================
"""bprop primitives"""
from ...common import dtype as mstype
from .. import functional as F
from .. import operations as P
from ..operations import _csr_ops
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from .grad_base import bprops, bprop_getters
@ -78,3 +80,80 @@ def get_bprop_sparse_tensor_dense_matmul(self):
values_grad = F.reduce_sum(parts_a * parts_b, 1)
return zeros_like(indices), values_grad, zeros_like(dense_shape), dense_grad
return bprop
@bprop_getters.register(_csr_ops.CSRReduceSum)
def get_bprop_csr_reduce_sum(self):
"Back-propagation for CSRReduceSum."
def bprop(csr_tensor, axis, out, dout):
indptr = csr_tensor.indptr
indices = csr_tensor.indices
shape = csr_tensor.shape
output_shape_kept_dims = F.reduced_shape(shape, axis)
tile_scaling = F.tuple_div(shape, output_shape_kept_dims)
values_grad_dense = F.tile(F.reshape(dout, output_shape_kept_dims), tile_scaling)
values_grad = F.csr_gather(indptr, indices, values_grad_dense, shape)
return F.make_csr_tensor(indptr, indices, values_grad, shape), zeros_like(axis)
return bprop
@bprop_getters.register(_csr_ops.CSRMV)
def get_bprop_csr_mv(self):
"Back-propagation for CSRMV."
def bprop(csr_tensor, dense, out, dout):
indptr = F.csr_tensor_get_indptr(csr_tensor)
indices = F.csr_tensor_get_indices(csr_tensor)
values = F.csr_tensor_get_values(csr_tensor)
dense_shape = csr_tensor.shape
rows = F.csr2coo(indptr, indices.shape[0])
idx_dtype = rows.dtype
rows_transposed, cols_indexing = F.sort(indices.astype(mstype.float32))
rows_transposed = rows_transposed.astype(idx_dtype)
cols_transposed = rows[cols_indexing]
values_transposed = values[cols_indexing]
indptr_transposed = F.coo2csr(rows_transposed, dense_shape[1])
csr_tensor_transposed = F.make_csr_tensor(
indptr_transposed, cols_transposed, values_transposed, (dense_shape[1], dense_shape[0]))
dense_grad = F.csr_mv(csr_tensor_transposed, dout)
parts_a = F.gather(dout, rows, 0)
parts_b = F.gather(dense, indices, 0)
values_grad = F.reduce_sum(parts_a * parts_b, 1)
return F.make_csr_tensor(indptr, indices, values_grad, csr_tensor.shape), dense_grad
return bprop
@bprop_getters.register(_csr_ops.CSRMul)
def get_bprop_csr_mul(self):
"Back-propagation for CSRMul."
def bprop(csr_tensor, dense, out, dout):
indptr = csr_tensor.indptr
indices = csr_tensor.indices
values = csr_tensor.values
shape = csr_tensor.shape
csr_tensor_grad_value = F.csr_mul(F.make_csr_tensor(indptr, indices, dout, shape), dense)
csr_tensor_grad = F.make_csr_tensor(indptr, indices, csr_tensor_grad_value, shape)
dense_grad_value = F.mul(dout, values)
dense_grad = F.make_csr_tensor(indptr, indices, dense_grad_value, shape)
if len(dense.shape) == 1 or dense.shape[0] == 1:
dense_grad = F.csr_reduce_sum(dense_grad, 0)
elif dense.shape[1] == 1:
dense_grad = F.csr_reduce_sum(dense_grad, 1)
else:
row = F.csr2coo(indptr, indices.shape[0])
coo_idx = P.Stack(-1)((row, indices))
dense_grad = F.tensor_scatter_update(zeros_like(dense), coo_idx, dense_grad_value)
return csr_tensor_grad, dense_grad
return bprop
@bprop_getters.register(_csr_ops.CSR2COO)
def get_bprop_csr2coo(self):
def bprop(indptr, nnz, out, dout):
return zeros_like(dout)
return bprop
@bprop_getters.register(_csr_ops.COO2CSR)
def get_bprop_coo2csr(self):
def bprop(row_indices, height, out, dout):
return zeros_like(dout)
return bprop
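
The least obvious step in get_bprop_csr_mv is the on-the-fly CSR transpose: the column indices become the row indices of the transposed matrix, a stable sort groups them, and the same permutation reorders the old row indices and the values. A NumPy sketch of that step (assuming sorted, in-range indices; not part of the commit):

import numpy as np

indptr = np.array([0, 1, 4, 6])
indices = np.array([3, 0, 1, 2, 1, 3])
values = np.arange(6, dtype=np.float32)
n_rows, n_cols = 3, 4

rows = np.repeat(np.arange(n_rows), np.diff(indptr))  # csr2coo
order = np.argsort(indices, kind="stable")            # rows of the transpose
rows_t, cols_t, vals_t = indices[order], rows[order], values[order]
indptr_t = np.concatenate(([0], np.cumsum(np.bincount(rows_t, minlength=n_cols))))  # coo2csr
print(indptr_t)  # [0 1 3 4 6]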

View File

@ -25,4 +25,7 @@ from .notequal import _notequal_akg
from .csr_reduce_sum import _csr_reduce_sum_akg
from .csr_mv import _csr_mv_akg
from .csr_mul import _csr_mul_akg
from .csr_gather import _csr_gather_akg
from .csr2coo import _csr2coo_akg
from .coo2csr import _coo2csr_akg
# Please insert op register in lexicographical order of the filename.

View File

@ -0,0 +1,29 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""COO2CSR op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
coo2csr_op_info = AkgGpuRegOp("COO2CSR") \
.fusion_type("OPAQUE") \
.input(0, "row_indices") \
.output(0, "output") \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(coo2csr_op_info)
def _coo2csr_akg():
"""COO2CSR AutoDiff register"""
return

View File

@ -0,0 +1,29 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CSR2COO op"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
csr2coo_op_info = AkgGpuRegOp("CSR2COO") \
.fusion_type("OPAQUE") \
.input(0, "indptr") \
.output(0, "output") \
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(csr2coo_op_info)
def _csr2coo_akg():
"""CSR2COO AutoDiff register"""
return

View File

@ -0,0 +1,33 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CSRGatherop"""
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
csr_gather_op_info = AkgGpuRegOp("CSRGather") \
.fusion_type("OPAQUE") \
.input(0, "indptr") \
.input(1, "indices") \
.input(2, "dense") \
.output(0, "output") \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \
DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
DataType.F32_Default) \
.get_op_info()
@op_info_register(csr_gather_op_info)
def _csr_gather_akg():
"""CSRGather AutoDiff register"""
return

View File

@ -22,9 +22,6 @@ csr_mv_op_info = AkgGpuRegOp("CSRMV") \
.input(2, "values") \
.input(4, "dense_tensor") \
.output(0, "output") \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \
DataType.F32_Default, \
DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
DataType.F32_Default, \
DataType.F32_Default) \

View File

@ -346,4 +346,19 @@ def _add_nonetensor_tensor(x, y):
"""
return x + y
@_add_backward.register("CSRTensor", "CSRTensor")
def _add_csrtensor_csrtensor(x, y):
"""
Adds CSRTensor and CSRTensor.
Args:
x (CSRTensor): x
y (CSRTensor): y
Returns:
CSRTensor.
"""
return F.make_csr_tensor(x.indptr, x.indices, x.values + y.values, x.shape)
hyper_add = base.HyperMap(_add_backward)

View File

@ -58,6 +58,12 @@ def _ones_like_coo_tensor(x):
return F.make_coo_tensor(F.coo_tensor_get_indices(x), values, F.coo_tensor_get_dense_shape(x))
@ones_like_leaf.register("CSRTensor")
def _ones_like_csr_tensor(x):
"""Returns a tensor with the same shape and dtype as x and all elements are 1."""
return F.make_csr_tensor(x.indptr, x.indices, ones_like(x.values), x.shape)
@ones_like_leaf.register("Function")
def _ones_like_func(x):
"""

View File

@ -58,6 +58,20 @@ def _zeros_like_tensor(x):
return F.zeros_like(x)
@zeros_like_leaf.register("COOTensor")
def _zeros_like_coo_tensor(x):
"""Returns a tensor with the same shape and dtype as x and all elements are 1."""
values = F.zeros_like(x.values)
return F.make_coo_tensor(x.indices, values, x.shape)
@zeros_like_leaf.register("CSRTensor")
def _zeros_like_csr_tensor(x):
"""Returns a tensor with the same shape and dtype as x and all elements are 1."""
values = F.zeros_like(x.values)
return F.make_csr_tensor(x.indptr, x.indices, values, x.shape)
@zeros_like_leaf.register("TypeType")
def _zeros_like_type_type(x):
"""Returns x because x is a type. This is usually used in backprop progress."""

View File

@ -152,6 +152,9 @@ stack = P.Stack()
csr_mul = _csr_ops.CSRMul()
csr_mv = _csr_ops.CSRMV()
csr_reduce_sum = _csr_ops.CSRReduceSum()
csr_gather = _csr_ops.CSRGather()
csr2coo = _csr_ops.CSR2COO()
coo2csr = _csr_ops.COO2CSR()
_select = P.Select()
@ -576,6 +579,7 @@ not_in_dict = Primitive("not_in_dict")
mixed_precision_cast = Primitive("mixed_precision_cast")
broadcast_gradient_args = Primitive('BroadcastGradientArgs')
array_reduce = Primitive('array_reduce')
zeros = P.Zeros()
zeros_like = P.ZerosLike()
distribute = Primitive('distribute')
embed = Primitive('embed')
@ -670,6 +674,11 @@ tensor_operator_registry.register('log', log)
tensor_operator_registry.register('floor', floor)
# support sparse tensor operators
tensor_operator_registry.register('csr_mul', csr_mul)
tensor_operator_registry.register('csr2coo', csr2coo)
tensor_operator_registry.register('coo2csr', coo2csr)
tensor_operator_registry.register('narrow', narrow)
tensor_operator_registry.register('sort', sort)
tensor_operator_registry.register('zeros', zeros)
tensor_operator_registry.register('tensor_scatter_update', tensor_scatter_update)
__all__ = [name for name in dir() if name[0] != "_"]
__all__.remove('Primitive')

View File

@ -20,6 +20,9 @@ class CSRReduceSum(PrimitiveWithInfer):
"""
Reduces a dimension of a CSRTensor by summing all elements in the dimension.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Inputs:
- **sparse_tensor** (CSRTensor) - A CSRTensor.
- **axis** (int) - The dimensions to reduce.
@ -64,6 +67,9 @@ class CSRMV(PrimitiveWithInfer):
"""
Sparse matrix-vector multiplication.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Inputs:
- **sparse_tensor** (CSRTensor) - A CSRTensor.
- **dense_tensor** (Tensor) - A dense Tensor.
@ -109,6 +115,9 @@ class CSRMul(PrimitiveWithInfer):
"""
Elemwise multiplication on a CSRTensor and a dense tensor.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Note:
The op outputs a 1-D dense tensor whose shape and values are the same as input `CSRTensor.values`.
If a CSRTensor output is expected, please use `*` directly, e.g. `x * y`, where `x` or `y` can be a CSRTensor.
@ -151,3 +160,129 @@ class CSRMul(PrimitiveWithInfer):
"""Initialize CSRMul"""
self.init_prim_io_names(inputs=['indptr', 'indices', 'values', 'dense_shape', 'dense_tensor'],
outputs=['output'])
class CSRGather(PrimitiveWithInfer):
"""
Gathers the values of a dense tensor at the positions described by a CSRTensor's indptr and indices.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Inputs:
- **indptr** (Tensor) - A Tensor.
- **indices** (Tensor) - A Tensor.
- **dense** (Tensor) - A Tensor.
- **sparse_shape** (tuple) - A tuple of integers.
Outputs:
Tensor, the dtype is the same as `dense`; the first dimension has the same size as `indices`, and the
remaining dimensions are the same as ``dense[2:]``.
Supported Platforms:
``GPU``
Examples:
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, ops
>>> from mindspore import dtype as mstype
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.op = ops.CSRGather()
...
... def construct(self, indptr, indices, dense, sparse_shape):
... return self.op(indptr, indices, dense, sparse_shape)
>>> indptr = Tensor([0, 1, 2])
>>> indices = Tensor([0, 1])
>>> sparse_shape = (2, 4)
>>> dense = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
>>> out = Net()(indptr, indices, dense, sparse_shape)
>>> print(out)
[1. 1.]
"""
@prim_attr_register
def __init__(self):
"""Initialize CSRGather"""
self.init_prim_io_names(inputs=['indptr', 'indices', 'dense', 'dense_shape'],
outputs=['output'])
class CSR2COO(PrimitiveWithInfer):
"""
Converts the indptr of a CSRTensor to the row indices of a COOTensor.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Inputs:
- **indptr** (Tensor) - A Tensor.
- **nnz** (int) - Denotes the number of non-zero elements in the sparse tensor.
Outputs:
Tensor, the dtype is the same as `indptr` and has shape (`nnz`,).
Supported Platforms:
``GPU``
Examples:
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, ops
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.op = ops.CSR2COO()
...
... def construct(self, indptr, nnz):
... return self.op(indptr, nnz)
>>> indptr = Tensor([0, 1, 2])
>>> out = Net()(indptr, 2)
>>> print(out)
[0 1]
"""
@prim_attr_register
def __init__(self):
"""Initialize CSR2COO"""
self.init_prim_io_names(inputs=['indptr', 'nnz'], outputs=['output'])
class COO2CSR(PrimitiveWithInfer):
"""
Converts the row indices of a COOTensor to the indptr of a CSRTensor.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Inputs:
- **row_indices** (Tensor) - A Tensor.
- **height** (int) - The size of the first dimension (number of rows) of the sparse tensor.
Outputs:
Tensor, the dtype is the same as `row_indices` and has shape (`height` + 1,).
Supported Platforms:
``GPU``
Examples:
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, ops
>>> from mindspore import dtype as mstype
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.op = ops.COO2CSR()
...
... def construct(self, row_indices, height):
... return self.op(row_indices, height)
>>> row_indices = Tensor([0, 1])
>>> out = Net()(row_indices, 2)
>>> print(out)
[0 1 2]
"""
@prim_attr_register
def __init__(self):
"""Initialize COO2CSR"""
self.init_prim_io_names(inputs=['row_indices', 'height'], outputs=['output'])
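
Taken together, the two conversion primitives are inverses of each other on well-formed inputs; a usage sketch (GPU backend assumed, per the Supported Platforms above):

from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops.operations import _csr_ops

csr2coo = _csr_ops.CSR2COO()
coo2csr = _csr_ops.COO2CSR()

indptr = Tensor([0, 1, 2], dtype=mstype.int32)
rows = csr2coo(indptr, 2)        # -> [0 1]
indptr_back = coo2csr(rows, 2)   # -> [0 1 2]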

View File

@ -91,3 +91,39 @@ def test_coo_tensor_in_while():
assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0)
assert np.allclose(out.values.asnumpy(), values.asnumpy(), .0, .0)
assert out.shape == shape
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_coo_method():
"""
Feature: Test coo tensor methods.
Description: Test coo_tensor.to_csr(), coo_tensor.to_dense().
Expectation: Success.
"""
class COOToCSRNet(nn.Cell):
def construct(self, coo_tensor):
return coo_tensor.to_csr()
class COOToDenseNet(nn.Cell):
def construct(self, coo_tensor):
return coo_tensor.to_dense()
indices = Tensor([[1, 2], [0, 1]], dtype=mstype.int32)
values = Tensor([2, 1], dtype=mstype.float32)
shape = (3, 4)
coo_tensor = COOTensor(indices, values, shape)
to_csr_output = COOToCSRNet()(coo_tensor)
to_csr_expect_1 = np.array([0, 1, 2, 2], dtype=np.int32)
to_csr_expect_2 = np.array([1, 2], dtype=np.int32)
to_csr_expect_3 = np.array([1, 2], dtype=np.float32)
assert np.allclose(to_csr_output.indptr.asnumpy(), to_csr_expect_1)
assert np.allclose(to_csr_output.indices.asnumpy(), to_csr_expect_2)
assert np.allclose(to_csr_output.values.asnumpy(), to_csr_expect_3)
to_dense_output = COOToDenseNet()(coo_tensor)
to_dense_expect = np.array(
[[0., 1., 0., 0.], [0., 0., 2., 0.], [0., 0., 0., 0.]], dtype=np.float32)
assert np.allclose(to_dense_output.asnumpy(), to_dense_expect)

View File

@ -18,7 +18,7 @@ import os
import pytest
import numpy as np
from mindspore import Tensor, CSRTensor, ms_function, nn, context
from mindspore import Tensor, CSRTensor, ms_function, nn, context, ops
from mindspore.ops.operations import _csr_ops
from mindspore.common import dtype as mstype
from mindspore.train.serialization import export, load
@ -232,8 +232,8 @@ def test_csr_ops():
csr_reducesum = _csr_ops.CSRReduceSum()
csrmv = _csr_ops.CSRMV()
indptr = Tensor([0, 1, 2])
indices = Tensor([0, 1])
indptr = Tensor([0, 1, 2], dtype=mstype.int32)
indices = Tensor([0, 1], dtype=mstype.int32)
values = Tensor([2, 1], dtype=mstype.float32)
dense_shape = (2, 4)
@ -331,8 +331,8 @@ def test_csrops_export_and_import_mindir():
sparse2 = dence_tensor * csr_tensor
return dense1, dense2, dense3, sparse1, sparse2
indptr = Tensor([0, 1, 2])
indices = Tensor([0, 1])
indptr = Tensor([0, 1, 2], dtype=mstype.int32)
indices = Tensor([0, 1], dtype=mstype.int32)
values = Tensor([2, 1], dtype=mstype.float32)
shape = (2, 4)
dense_tensor = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
@ -428,3 +428,111 @@ def test_dtype_csr_tensor():
out2 = graph_test()
assert out1 in [mstype.float32]
assert out2 in [mstype.float32]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_csr_bprop():
"""
Feature: Test back-propagation with CSR-related Ops.
Description: Test CSRReduceSum, CSRMul, CSRMV, CSRTensor.to_coo(), CSRTensor.to_dense().
Expectation: Success.
"""
class CSRMulNet(nn.Cell):
def __init__(self):
super(CSRMulNet, self).__init__()
self.op = _csr_ops.CSRMul()
def construct(self, csr_tensor, dense):
return self.op(csr_tensor, dense)
class CSRReduceSumNet(nn.Cell):
def __init__(self):
super(CSRReduceSumNet, self).__init__()
self.op = _csr_ops.CSRReduceSum()
def construct(self, csr_tensor, axis):
return self.op(csr_tensor, axis)
class CSRMVNet(nn.Cell):
def __init__(self):
super(CSRMVNet, self).__init__()
self.op = _csr_ops.CSRMV()
def construct(self, csr_tensor, dense):
return self.op(csr_tensor, dense)
class BpropNet(nn.Cell):
def __init__(self, net):
super(BpropNet, self).__init__()
self.net = net
self.grad_op = ops.GradOperation(get_all=True)
def construct(self, *inputs):
return self.grad_op(self.net)(*inputs)
indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
values = Tensor(np.arange(6), dtype=mstype.float32)
dense_shape = (3, 4)
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
csr_mv_arg = Tensor([[1], [2], [3], [4]], dtype=mstype.float32)
csr_mv_output_1, csr_mv_output_2 = BpropNet(CSRMVNet())(csr_tensor, csr_mv_arg)
csr_mv_expect_1 = np.array([4, 1, 2, 3, 2, 4], dtype=np.float32)
csr_mv_expect_2 = np.array([[1], [6], [3], [5]], dtype=np.float32)
assert np.allclose(csr_mv_output_1.values.asnumpy(), csr_mv_expect_1)
assert np.allclose(csr_mv_output_2.asnumpy(), csr_mv_expect_2)
csr_reduce_sum_output = BpropNet(CSRReduceSumNet())(csr_tensor, 1)
csr_reduce_sum_expect = np.ones(6, dtype=np.float32)
assert np.allclose(csr_reduce_sum_output[0].values.asnumpy(), csr_reduce_sum_expect)
csr_mul_arg_1 = Tensor([[1], [2], [3]], dtype=mstype.float32)
csr_mul_output_1_1, csr_mul_output_1_2 = BpropNet(CSRMulNet())(csr_tensor, csr_mul_arg_1)
csr_mul_expect_1_1 = np.array([1, 2, 2, 2, 3, 3], dtype=np.float32)
csr_mul_expect_1_2 = np.array([[0], [6], [9]], dtype=np.float32)
assert np.allclose(csr_mul_output_1_1.values.asnumpy(), csr_mul_expect_1_1)
assert np.allclose(csr_mul_output_1_2.asnumpy(), csr_mul_expect_1_2)
csr_mul_arg_2 = Tensor(np.arange(12).reshape(3, 4), dtype=mstype.float32)
csr_mul_output_2_1, csr_mul_output_2_2 = BpropNet(CSRMulNet())(csr_tensor, csr_mul_arg_2)
csr_mul_expect_2_1 = np.array([3, 4, 5, 6, 9, 11], dtype=np.float32)
csr_mul_expect_2_2 = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32)
assert np.allclose(csr_mul_output_2_1.values.asnumpy(), csr_mul_expect_2_1)
assert np.allclose(csr_mul_output_2_2.asnumpy(), csr_mul_expect_2_2)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_csr_method():
"""
Feature: Test csr tensor methods.
Description: Test csr_tensor.to_coo(), csr_tensor.to_dense().
Expectation: Success.
"""
class CSRToCOONet(nn.Cell):
def construct(self, csr_tensor):
return csr_tensor.to_coo()
class CSRToDenseNet(nn.Cell):
def construct(self, csr_tensor):
return csr_tensor.to_dense()
indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
values = Tensor(np.arange(6), dtype=mstype.float32)
dense_shape = (3, 4)
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
to_coo_output = CSRToCOONet()(csr_tensor)
to_coo_expect_1 = np.array([[0, 3], [1, 0], [1, 1], [1, 2], [2, 1], [2, 3]], dtype=np.int32)
to_coo_expect_2 = np.arange(6).astype(np.float32)
assert np.allclose(to_coo_output.indices.asnumpy(), to_coo_expect_1)
assert np.allclose(to_coo_output.values.asnumpy(), to_coo_expect_2)
to_dense_output = CSRToDenseNet()(csr_tensor)
to_dense_expect = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32)
assert np.allclose(to_dense_output.asnumpy(), to_dense_expect)