forked from mindspore-Ecosystem/mindspore
!29107 add sparse tensor gradient operations && tensor methods
Merge pull request !29107 from huangmengxi/csr_method
This commit is contained in:
commit
1c186ce115
|
@ -58,6 +58,9 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
|
|||
Register(prim::kPrimUnsortedSegmentMin->name(), {2});
|
||||
Register(prim::kPrimUnsortedSegmentMax->name(), {2});
|
||||
Register(prim::kPrimCSRReduceSum->name(), {1});
|
||||
Register(prim::kPrimCSRGather->name(), {3});
|
||||
Register(prim::kPrimCSR2COO->name(), {1});
|
||||
Register(prim::kPrimCOO2CSR->name(), {1});
|
||||
Register(kSparseGatherV2OpName, {2});
|
||||
Register(kUnsortedSegmentProdOpName, {2});
|
||||
Register(kSimpleMeanGradOpName, {1});
|
||||
|
|
|
@ -191,7 +191,6 @@ const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const An
|
|||
}
|
||||
auto new_node = cnode->func_graph()->NewCNode(new_inputs);
|
||||
new_node->set_abstract(node->abstract());
|
||||
AnfAlgo::SetNodeAttr("is_csr", MakeValue(true), new_node);
|
||||
return new_node;
|
||||
}
|
||||
|
||||
|
|
|
@ -40,6 +40,8 @@ using MetaTensor = mindspore::tensor::MetaTensor;
|
|||
using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
|
||||
using CSRTensor = mindspore::tensor::CSRTensor;
|
||||
using CSRTensorPtr = mindspore::tensor::CSRTensorPtr;
|
||||
using COOTensor = mindspore::tensor::COOTensor;
|
||||
using COOTensorPtr = mindspore::tensor::COOTensorPtr;
|
||||
|
||||
using InstanceCheckFunc = std::function<bool(const py::object &)>;
|
||||
using InstanceConvertFunc = std::function<ValuePtr(const py::object &, bool, const TypePtr &)>;
|
||||
|
@ -489,6 +491,7 @@ std::vector<DataConverterPtr> GetDataConverters() {
|
|||
std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
|
||||
std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
|
||||
std::make_shared<ByTypeDataConverter<CSRTensor>>(ObjCast<CSRTensorPtr>),
|
||||
std::make_shared<ByTypeDataConverter<COOTensor>>(ObjCast<COOTensorPtr>),
|
||||
std::make_shared<ByTypeDataConverter<py::tuple>>(ConvertTuple),
|
||||
std::make_shared<ByTypeDataConverter<py::list>>(ConvertList),
|
||||
std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>),
|
||||
|
|
|
@ -98,6 +98,7 @@ namespace pipeline {
|
|||
using Tensor = mindspore::tensor::Tensor;
|
||||
using MetaTensor = mindspore::tensor::MetaTensor;
|
||||
using CSRTensor = mindspore::tensor::CSRTensor;
|
||||
using COOTensor = mindspore::tensor::COOTensor;
|
||||
using TensorOrderMap = std::map<std::string, std::shared_ptr<Tensor>>;
|
||||
using mindspore::abstract::AbstractTensor;
|
||||
using mindspore::abstract::AbstractTensorPtr;
|
||||
|
@ -178,7 +179,8 @@ bool CheckArgValid(const py::handle &arg) {
|
|||
|
||||
return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<py::none>(arg) ||
|
||||
py::isinstance<Number>(arg) ||
|
||||
((py::isinstance<Tensor>(arg) || py::isinstance<CSRTensor>(arg)) && !py::hasattr(arg, "__parameter__"));
|
||||
((py::isinstance<Tensor>(arg) || py::isinstance<CSRTensor>(arg) || py::isinstance<COOTensor>(arg)) &&
|
||||
!py::hasattr(arg, "__parameter__"));
|
||||
}
|
||||
|
||||
std::string GetCompileExceptionInfo() {
|
||||
|
|
|
@ -216,7 +216,17 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
}},
|
||||
{kObjectTypeJTagged, {}},
|
||||
{kObjectTypeSymbolicKeyType, {}},
|
||||
{kObjectTypeEnvType, {}}};
|
||||
{kObjectTypeEnvType, {}},
|
||||
{kObjectTypeCOOTensorType,
|
||||
{
|
||||
{"to_csr", std::string("coo_to_csr")},
|
||||
{"to_dense", std::string("coo_to_dense")},
|
||||
}},
|
||||
{kObjectTypeCSRTensorType,
|
||||
{
|
||||
{"to_coo", std::string("csr_to_coo")},
|
||||
{"to_dense", std::string("csr_to_dense")},
|
||||
}}};
|
||||
return method_map;
|
||||
}
|
||||
|
||||
|
|
|
@ -320,7 +320,8 @@ size_t CountValueNum(const ValueTuplePtr &value_tuple) {
|
|||
|
||||
bool IsCustomCSROP(const AnfNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul, prim::kPrimCSRMV};
|
||||
const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul, prim::kPrimCSRMV,
|
||||
prim::kPrimCSRGather, prim::kPrimCSR2COO, prim::kPrimCOO2CSR};
|
||||
return IsOneOfPrimitiveCNode(cnode, prims);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -94,8 +94,13 @@ const mindspore::HashSet<std::string> make_sparse_set = {{prim::kMakeCSRTensor},
|
|||
// sparse_op_set records all sparse_compute operators, which takes sparsetensor
|
||||
// and (possibly) dense tensors, used in backend common optimization pass:
|
||||
// sparse_process.cc
|
||||
// Single definition of sparse_op_set; the earlier five-element definition was a stale duplicate
// that redefined the same constant.
// NOTE(review): COO2CSR is listed in IsCustomCSROP but not here — confirm whether that is intended.
const mindspore::HashSet<std::string> sparse_op_set = {{prim::kSparseTensorDenseMatmul},
                                                       {prim::kCSRDenseMul},
                                                       {prim::kCSRReduceSum},
                                                       {prim::kCSRMV},
                                                       {prim::kCSRMul},
                                                       {prim::kCSRGather},
                                                       {prim::kCSR2COO}};
|
||||
|
||||
bool IsCustomCSROP(const AnfNodePtr &cnode);
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -165,6 +165,12 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -36,6 +36,7 @@ namespace abstract {
|
|||
constexpr auto kCSRDenseShape = "dense_shape";
|
||||
constexpr auto kCSRAxis = "axis";
|
||||
constexpr auto kCSRAvgRows = "csr_avg_rows";
|
||||
constexpr auto kIsCSR = "is_csr";
|
||||
AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// An object of a subclass of AbstractBase
|
||||
|
@ -439,6 +440,9 @@ AbstractBasePtr InferImplCSRMul(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|||
<< "but sparse tensor has " << sparse_shape.size() << " dimensions, "
|
||||
<< "and dense tensor has " << dense_shape.size() << " dimensions, ";
|
||||
}
|
||||
if (dense_shape[0] != sparse_shape[0]) {
|
||||
MS_EXCEPTION(ValueError) << "Currently, only support dense tensor broadcast with last dim!";
|
||||
}
|
||||
auto ret = sparse->values()->Broaden();
|
||||
|
||||
MS_EXCEPTION_IF_NULL(sparse->indices()->shape());
|
||||
|
@ -446,7 +450,7 @@ AbstractBasePtr InferImplCSRMul(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|||
int csr_avg_rows = SizeToInt(nnz_vec[0] / dense_shape[0]);
|
||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||
primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape));
|
||||
|
||||
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -482,7 +486,7 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr
|
|||
int csr_avg_rows = SizeToInt(nnz_vec[0] / dense_shape[0]);
|
||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||
primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape));
|
||||
|
||||
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -532,7 +536,98 @@ AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const Primitive
|
|||
int csr_avg_rows = SizeToInt(nnz_vec[0] / sparse_shape[0]);
|
||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||
primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape));
|
||||
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Infers the abstract output of CSRGather.
//
// Inputs (in order): the indptr and indices of a sparse csr tensor, a dense tensor, and the shape
// of the sparse tensor (a tuple). The result is a tensor with the element type of `dense` and the
// shape of `indices`. As a side effect, the `csr_avg_rows` and `is_csr` attributes are set on
// `primitive`.
//
// Raises ValueError when the sparse shape is not 2-D, when `indices` has an empty shape, or when
// the first dimension of the sparse shape is zero.
AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
  constexpr auto kCSRShapeSize = 2;
  constexpr auto kCSRArgsSize = 4;
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);
  auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
  auto sparse_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 3);
  MS_EXCEPTION_IF_NULL(indptr);
  MS_EXCEPTION_IF_NULL(indices);
  MS_EXCEPTION_IF_NULL(dense);
  MS_EXCEPTION_IF_NULL(sparse_shape);

  if (sparse_shape->size() != kCSRShapeSize) {
    MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRShapeSize << "-D inputs!"
                             << "But sparse tensor has " << sparse_shape->size() << " dimensions.";
  }

  MS_EXCEPTION_IF_NULL(sparse_shape->BuildValue());
  auto shape_value = sparse_shape->BuildValue()->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(shape_value);

  // The indices shape must be validated before it is dereferenced; it supplies both the
  // non-zero count and the output shape.
  MS_EXCEPTION_IF_NULL(indices->shape());
  ShapeVector out_shape = indices->shape()->shape();
  if (out_shape.empty()) {
    MS_EXCEPTION(ValueError) << "For " << op_name << ", the shape of indices must not be empty.";
  }

  // Integer division by zero is undefined behaviour, so reject a zero row count explicitly.
  const int64_t num_rows = GetValue<int64_t>(shape_value->value()[0]);
  if (num_rows == 0) {
    MS_EXCEPTION(ValueError) << "For " << op_name << ", the first dimension of the sparse shape must not be 0.";
  }
  int64_t csr_avg_rows = out_shape[0] / num_rows;
  primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
  primitive->set_attr(kIsCSR, MakeValue(true));

  MS_EXCEPTION_IF_NULL(dense->element());
  auto ret = std::make_shared<AbstractTensor>(dense->element()->BuildType(), out_shape);
  return ret;
}
|
||||
|
||||
// Infers the abstract output of CSR2COO.
//
// Inputs (in order): the indptr of a sparse csr tensor, and the number of non-zero elements.
// The result is a 1-D tensor of length nnz with the element type of `indptr`. As a side effect,
// the `csr_avg_rows` and `is_csr` attributes are set on `primitive`.
//
// Raises ValueError when nnz is not an Int32/Int64 immediate, when `indptr` has an empty shape,
// or when `indptr` describes zero rows (first dimension equal to 1).
AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
  constexpr auto kCSRArgsSize = 2;
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);
  auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  auto nnz = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  MS_EXCEPTION_IF_NULL(indptr);
  MS_EXCEPTION_IF_NULL(nnz);

  // Build the nnz value once and reuse it for both the output shape and the average.
  auto nnz_build_value = nnz->BuildValue();
  MS_EXCEPTION_IF_NULL(nnz_build_value);
  if (!nnz_build_value->isa<Int32Imm>() && !nnz_build_value->isa<Int64Imm>()) {
    MS_EXCEPTION(ValueError) << "Currently, only support Integer nnz.";
  }
  const int64_t nnz_value = GetValue<int64_t>(nnz_build_value);
  ShapeVector out_shape;
  out_shape.push_back(nnz_value);

  MS_EXCEPTION_IF_NULL(indptr->shape());
  const ShapeVector indptr_shape = indptr->shape()->shape();
  if (indptr_shape.empty()) {
    MS_EXCEPTION(ValueError) << "For " << op_name << ", the shape of indptr must not be empty.";
  }
  // indptr holds one more entry than there are rows.
  const int64_t num_rows = indptr_shape[0] - 1;
  // Integer division by zero is undefined behaviour, so reject a zero row count explicitly.
  if (num_rows == 0) {
    MS_EXCEPTION(ValueError) << "For " << op_name << ", indptr must describe at least one row.";
  }
  // The attribute stays an `int`, as before; the narrowing is now explicit.
  int csr_avg_rows = static_cast<int>(nnz_value / num_rows);
  primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
  primitive->set_attr(kIsCSR, MakeValue(true));

  MS_EXCEPTION_IF_NULL(indptr->element());
  auto ret = std::make_shared<AbstractTensor>(indptr->element()->BuildType(), out_shape);
  return ret;
}
|
||||
|
||||
// Infers the abstract output of COO2CSR.
//
// Inputs (in order): the row indices of a sparse coo tensor, and the size of its first dimension.
// The result is a 1-D tensor of length (height + 1) with the element type of the row indices.
// Raises ValueError when height is not an Int32/Int64 immediate.
AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
  constexpr auto kCSRArgsSize = 2;
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);

  auto row_indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  auto height = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
  MS_EXCEPTION_IF_NULL(row_indices);
  MS_EXCEPTION_IF_NULL(height);

  auto height_build_value = height->BuildValue();
  MS_EXCEPTION_IF_NULL(height_build_value);
  const bool height_is_integer = height_build_value->isa<Int32Imm>() || height_build_value->isa<Int64Imm>();
  if (!height_is_integer) {
    MS_EXCEPTION(ValueError) << "Currently, only support Integer height.";
  }

  // One indptr entry per row plus a trailing end marker.
  ShapeVector out_shape;
  out_shape.push_back(GetValue<int64_t>(height_build_value) + 1);

  MS_EXCEPTION_IF_NULL(row_indices->element());
  return std::make_shared<AbstractTensor>(row_indices->element()->BuildType(), out_shape);
}
|
||||
|
||||
|
|
|
@ -236,6 +236,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimCSRMul, R{InferImplCSRMul, nullptr, true}},
|
||||
{prim::kPrimCSRMV, R{InferImplCSRMV, nullptr, true}},
|
||||
{prim::kPrimCSRReduceSum, R{InferImplCSRReduceSum, nullptr, true}},
|
||||
{prim::kPrimCSRGather, R{InferImplCSRGather, nullptr, true}},
|
||||
{prim::kPrimCSR2COO, R{InferImplCSR2COO, nullptr, true}},
|
||||
{prim::kPrimCOO2CSR, R{InferImplCOO2CSR, nullptr, true}},
|
||||
// Comm Ops
|
||||
{prim::kPrimAllSwap, R{InferImplAllSwap, nullptr, true}},
|
||||
{prim::kPrimMemCpyAsync, R{InferImplMemCpyAsync, nullptr, true}},
|
||||
|
|
|
@ -164,6 +164,9 @@ constexpr auto kCSRDenseMul = "CSRDenseMul";
|
|||
constexpr auto kCSRReduceSum = "CSRReduceSum";
|
||||
constexpr auto kCSRMV = "CSRMV";
|
||||
constexpr auto kCSRMul = "CSRMul";
|
||||
constexpr auto kCSRGather = "CSRGather";
|
||||
constexpr auto kCSR2COO = "CSR2COO";
|
||||
constexpr auto kCOO2CSR = "COO2CSR";
|
||||
|
||||
// Meta Function Graph
|
||||
constexpr auto kJ = "J";
|
||||
|
@ -606,6 +609,9 @@ GVAR_DEF(PrimitivePtr, kPrimCSRDenseMul, std::make_shared<Primitive>(kCSRDenseMu
|
|||
GVAR_DEF(PrimitivePtr, kPrimCSRReduceSum, std::make_shared<Primitive>(kCSRReduceSum));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCSRMV, std::make_shared<Primitive>(kCSRMV));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCSRMul, std::make_shared<Primitive>(kCSRMul));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCSRGather, std::make_shared<Primitive>(kCSRGather));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCSR2COO, std::make_shared<Primitive>(kCSR2COO));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCOO2CSR, std::make_shared<Primitive>(kCOO2CSR));
|
||||
|
||||
// TensorList
|
||||
GVAR_DEF(PrimitivePtr, kPrimTensorListFromTensor, std::make_shared<Primitive>("TensorListFromTensor"));
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore import Tensor, Parameter, CSRTensor, COOTensor
|
||||
from mindspore import dtype as mstype
|
||||
|
||||
from ..._checkparam import Validator as validator
|
||||
|
@ -1573,6 +1573,32 @@ def while_cond(x):
|
|||
return x
|
||||
|
||||
|
||||
def coo_to_csr(x):
    """Convert a COOTensor `x` into a CSRTensor with the same shape.

    Entries are reordered by ascending row index before the row pointer is built.
    """
    coo_indices = x.indices
    index_dtype = coo_indices.dtype
    # The row indices are sorted as float32 and cast back afterwards.
    # NOTE(review): float32 represents integers exactly only up to 2**24 — confirm that larger
    # row indices cannot occur here.
    sorted_rows, order = F.sort(coo_indices[:, 0].astype(mstype.float32))
    sorted_rows = sorted_rows.astype(index_dtype)
    sorted_cols = coo_indices[:, 1][order]
    sorted_values = x.values[order]
    indptr = F.coo2csr(sorted_rows, x.shape[0])
    return CSRTensor(indptr, sorted_cols, sorted_values, x.shape)
|
||||
|
||||
|
||||
def coo_to_dense(x):
    """Convert a COOTensor `x` into a dense tensor by scattering its values into zeros."""
    dense_zeros = F.zeros(x.shape, x.values.dtype)
    dense = F.tensor_scatter_update(dense_zeros, x.indices, x.values)
    return dense
|
||||
|
||||
def csr_to_coo(x):
    """Convert a CSRTensor `x` into a COOTensor with the same values and shape."""
    nnz = x.values.shape[0]
    rows = F.csr2coo(x.indptr, nnz)
    # Pair every row index with its column index along axis 1.
    stacked_indices = P.Stack(1)((rows, x.indices))
    return COOTensor(stacked_indices, x.values, x.shape)
|
||||
|
||||
def csr_to_dense(x):
    """Convert a CSRTensor `x` into a dense tensor by going through its COO form."""
    return x.to_coo().to_dense()
|
||||
|
||||
|
||||
@constexpr
|
||||
def empty_tensor(dtype):
|
||||
return Tensor([], dtype)
|
||||
|
|
|
@ -290,7 +290,7 @@ class _MindsporeFunctionExecutor:
|
|||
return None
|
||||
new_inputs = []
|
||||
for i in args_list:
|
||||
if isinstance(i, (Tensor, CSRTensor)):
|
||||
if isinstance(i, (Tensor, CSRTensor, COOTensor)):
|
||||
new_inputs.append(i)
|
||||
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
|
||||
new_inputs.append(i)
|
||||
|
|
|
@ -2468,6 +2468,24 @@ class COOTensor(COOTensor_):
|
|||
def shape(self):
|
||||
return self._shape
|
||||
|
||||
def to_csr(self):
    """Converts COOTensor to CSRTensor."""
    sort_op = tensor_operator_registry.get("sort")
    coo2csr_op = tensor_operator_registry.get("coo2csr")

    coo_indices = self.indices
    index_dtype = coo_indices.dtype
    # The row indices are sorted as float32 and cast back afterwards.
    # NOTE(review): float32 represents integers exactly only up to 2**24 — confirm that larger
    # row indices cannot occur here.
    sorted_rows, order = sort_op(coo_indices[:, 0].astype(mstype.float32))
    sorted_rows = sorted_rows.astype(index_dtype)
    sorted_cols = coo_indices[:, 1][order]
    sorted_values = self.values[order]
    indptr = coo2csr_op(sorted_rows, self.shape[0])
    return CSRTensor(indptr, sorted_cols, sorted_values, self.shape)
|
||||
|
||||
def to_dense(self):
    """Converts COOTensor to a dense tensor by scattering its values into zeros."""
    make_zeros = tensor_operator_registry.get("zeros")
    scatter_update = tensor_operator_registry.get("tensor_scatter_update")
    dense_zeros = make_zeros(self.shape, self.values.dtype)
    return scatter_update(dense_zeros, self.indices, self.values)
|
||||
|
||||
|
||||
class CSRTensor(CSRTensor_):
|
||||
"""
|
||||
|
@ -2566,6 +2584,15 @@ class CSRTensor(CSRTensor_):
|
|||
def to_tuple(self):
|
||||
return self.indptr, self.indices, self.values, self.shape
|
||||
|
||||
def to_coo(self):
    """Converts CSRTensor to COOTensor."""
    csr2coo_op = tensor_operator_registry.get("csr2coo")
    stack_op = tensor_operator_registry.get("stack")
    nnz = self.values.shape[0]
    rows = csr2coo_op(self.indptr, nnz)
    # Pair every row index with its column index along axis 1.
    stacked_indices = stack_op(1)((rows, self.indices))
    return COOTensor(stacked_indices, self.values, self.shape)
|
||||
|
||||
def to_dense(self):
    """Converts CSRTensor to a dense tensor by going through its COO form."""
    return self.to_coo().to_dense()
|
||||
|
||||
|
||||
def _vm_compare(*args):
|
||||
"""Implement `vm_compare` for tensor."""
|
||||
|
|
|
@ -32,7 +32,7 @@ from .._checkparam import Validator
|
|||
from ..common import dtype as mstype
|
||||
from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
|
||||
from ..common.parameter import Parameter, ParameterTuple
|
||||
from ..common.tensor import Tensor, CSRTensor
|
||||
from ..common.tensor import Tensor, CSRTensor, COOTensor
|
||||
from ..ops.operations import Cast
|
||||
from ..ops.primitive import Primitive
|
||||
from ..ops.operations import _inner_ops as inner
|
||||
|
@ -815,6 +815,8 @@ class Cell(Cell_):
|
|||
if i.has_init:
|
||||
i.init_data()
|
||||
new_inputs.append(i)
|
||||
elif isinstance(i, COOTensor):
|
||||
new_inputs.append(i)
|
||||
elif isinstance(i, CSRTensor):
|
||||
new_inputs.append(i)
|
||||
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
|
||||
|
@ -1267,16 +1269,16 @@ class Cell(Cell_):
|
|||
|
||||
def _add_mixed_precision_flag(self, **flags):
|
||||
"""Add mixed precision flag to current cell"""
|
||||
if "fp16" in flags and flags["fp16"]:
|
||||
if "fp16" in flags and flags.get("fp16", False):
|
||||
Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16)
|
||||
if "fp32" in flags and flags["fp32"]:
|
||||
if "fp32" in flags and flags.get("fp32", False):
|
||||
Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32)
|
||||
|
||||
def _add_mixed_precision_flag_recursive(self, **flags):
|
||||
"""Add mixed precision flag to each cell"""
|
||||
if "fp16" in flags and flags["fp16"]:
|
||||
if "fp16" in flags and flags.get("fp16", False):
|
||||
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16)
|
||||
if "fp32" in flags and flags["fp32"]:
|
||||
if "fp32" in flags and flags.get("fp32", False):
|
||||
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
|
||||
|
||||
def add_flags(self, **flags):
|
||||
|
@ -1876,15 +1878,16 @@ class Cell(Cell_):
|
|||
"""
|
||||
self._recompute()
|
||||
if 'mp_comm_recompute' in kwargs.keys():
|
||||
self._mp_comm_recompute(kwargs['mp_comm_recompute'])
|
||||
self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
|
||||
if 'parallel_optimizer_comm_recompute' in kwargs.keys():
|
||||
if kwargs['parallel_optimizer_comm_recompute'] and context.get_auto_parallel_context("pipeline_stages") > 1:
|
||||
if (kwargs.get('parallel_optimizer_comm_recompute', False) and
|
||||
context.get_auto_parallel_context("pipeline_stages") > 1):
|
||||
logger.warning("Currently, the communication operator allgathers introduced by optimizer shard "
|
||||
"are not support recomputation in pipeline parallel.")
|
||||
elif context.get_auto_parallel_context("pipeline_stages") == 1:
|
||||
self._parallel_optimizer_comm_recompute(kwargs['parallel_optimizer_comm_recompute'])
|
||||
self._parallel_optimizer_comm_recompute(kwargs.get('parallel_optimizer_comm_recompute', False))
|
||||
if 'recompute_slice_activation' in kwargs.keys():
|
||||
self._recompute_slice_activation(kwargs['recompute_slice_activation'])
|
||||
self._recompute_slice_activation(kwargs.get('recompute_slice_activation', False))
|
||||
|
||||
for key, _ in kwargs.items():
|
||||
if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute', 'recompute_slice_activation'):
|
||||
|
|
|
@ -14,8 +14,10 @@
|
|||
# ============================================================================
|
||||
|
||||
"""bprop primitives"""
|
||||
from ...common import dtype as mstype
|
||||
from .. import functional as F
|
||||
from .. import operations as P
|
||||
from ..operations import _csr_ops
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from .grad_base import bprops, bprop_getters
|
||||
|
||||
|
@ -78,3 +80,80 @@ def get_bprop_sparse_tensor_dense_matmul(self):
|
|||
values_grad = F.reduce_sum(parts_a * parts_b, 1)
|
||||
return zeros_like(indices), values_grad, zeros_like(dense_shape), dense_grad
|
||||
return bprop
|
||||
|
||||
@bprop_getters.register(_csr_ops.CSRReduceSum)
def get_bprop_csr_reduce_sum(self):
    """Back-propagation for CSRReduceSum."""
    def bprop(csr_tensor, axis, out, dout):
        csr_indptr = csr_tensor.indptr
        csr_indices = csr_tensor.indices
        csr_shape = csr_tensor.shape

        # Reshape dout to the reduced shape with kept dims, then tile it back to the full shape.
        kept_dims_shape = F.reduced_shape(csr_shape, axis)
        repeats = F.tuple_div(csr_shape, kept_dims_shape)
        dout_kept_dims = F.reshape(dout, kept_dims_shape)
        dense_values_grad = F.tile(dout_kept_dims, repeats)
        # Gather from the dense gradient using the CSR indptr and indices.
        values_grad = F.csr_gather(csr_indptr, csr_indices, dense_values_grad, csr_shape)
        csr_grad = F.make_csr_tensor(csr_indptr, csr_indices, values_grad, csr_shape)
        return csr_grad, zeros_like(axis)
    return bprop
|
||||
|
||||
@bprop_getters.register(_csr_ops.CSRMV)
def get_bprop_csr_mv(self):
    """Back-propagation for CSRMV."""
    def bprop(csr_tensor, dense, out, dout):
        csr_indptr = F.csr_tensor_get_indptr(csr_tensor)
        csr_indices = F.csr_tensor_get_indices(csr_tensor)
        csr_values = F.csr_tensor_get_values(csr_tensor)
        csr_shape = csr_tensor.shape

        # Build the transposed CSR tensor: sort the entries by column index, then compress.
        row_ids = F.csr2coo(csr_indptr, csr_indices.shape[0])
        index_dtype = row_ids.dtype
        sorted_cols, order = F.sort(csr_indices.astype(mstype.float32))
        sorted_cols = sorted_cols.astype(index_dtype)
        transposed_indices = row_ids[order]
        transposed_values = csr_values[order]
        transposed_indptr = F.coo2csr(sorted_cols, csr_shape[1])
        transposed_shape = (csr_shape[1], csr_shape[0])
        transposed_csr = F.make_csr_tensor(
            transposed_indptr, transposed_indices, transposed_values, transposed_shape)

        dense_grad = F.csr_mv(transposed_csr, dout)

        # Each value's gradient is the row-wise reduction of dout[row] * dense[col].
        dout_rows = F.gather(dout, row_ids, 0)
        dense_rows = F.gather(dense, csr_indices, 0)
        values_grad = F.reduce_sum(dout_rows * dense_rows, 1)
        csr_grad = F.make_csr_tensor(csr_indptr, csr_indices, values_grad, csr_shape)
        return csr_grad, dense_grad
    return bprop
|
||||
|
||||
@bprop_getters.register(_csr_ops.CSRMul)
def get_bprop_csr_mul(self):
    """Back-propagation for CSRMul."""
    def bprop(csr_tensor, dense, out, dout):
        csr_indptr = csr_tensor.indptr
        csr_indices = csr_tensor.indices
        csr_values = csr_tensor.values
        csr_shape = csr_tensor.shape

        # Gradient w.r.t. the sparse operand: dout (as a CSR tensor) multiplied by dense.
        dout_csr = F.make_csr_tensor(csr_indptr, csr_indices, dout, csr_shape)
        sparse_grad_values = F.csr_mul(dout_csr, dense)
        csr_tensor_grad = F.make_csr_tensor(csr_indptr, csr_indices, sparse_grad_values, csr_shape)

        # Gradient w.r.t. the dense operand starts as dout * values, then is reduced or scattered
        # depending on the shape of dense.
        dense_grad_value = F.mul(dout, csr_values)
        dense_grad = F.make_csr_tensor(csr_indptr, csr_indices, dense_grad_value, csr_shape)
        if len(dense.shape) == 1 or dense.shape[0] == 1:
            dense_grad = F.csr_reduce_sum(dense_grad, 0)
        elif dense.shape[1] == 1:
            dense_grad = F.csr_reduce_sum(dense_grad, 1)
        else:
            row_ids = F.csr2coo(csr_indptr, csr_indices.shape[0])
            coo_positions = P.Stack(-1)((row_ids, csr_indices))
            dense_grad = F.tensor_scatter_update(zeros_like(dense), coo_positions, dense_grad_value)
        return csr_tensor_grad, dense_grad
    return bprop
|
||||
|
||||
@bprop_getters.register(_csr_ops.CSR2COO)
def get_bprop_csr2coo(self):
    """Back-propagation for CSR2COO."""
    def bprop(indptr, nnz, out, dout):
        # NOTE(review): a single zeros_like(dout) is returned although bprop receives two
        # inputs (indptr, nnz) — confirm this is the intended gradient layout.
        zero_grad = zeros_like(dout)
        return zero_grad
    return bprop
|
||||
|
||||
@bprop_getters.register(_csr_ops.COO2CSR)
def get_bprop_coo2csr(self):
    """Back-propagation for COO2CSR."""
    def bprop(row_indices, height, out, dout):
        # NOTE(review): a single zeros_like(dout) is returned although bprop receives two
        # inputs (row_indices, height) — confirm this is the intended gradient layout.
        zero_grad = zeros_like(dout)
        return zero_grad
    return bprop
|
||||
|
|
|
@ -25,4 +25,7 @@ from .notequal import _notequal_akg
|
|||
from .csr_reduce_sum import _csr_reduce_sum_akg
|
||||
from .csr_mv import _csr_mv_akg
|
||||
from .csr_mul import _csr_mul_akg
|
||||
from .csr_gather import _csr_gather_akg
|
||||
from .csr2coo import _csr2coo_akg
|
||||
from .coo2csr import _coo2csr_akg
|
||||
# Please insert op register in lexicographical order of the filename.
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""COO2CSR op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
|
||||
|
||||
# AKG GPU registration for COO2CSR: one tensor input (row_indices) and one output, with matching
# int64 or int32 dtypes.
coo2csr_op_info = (
    AkgGpuRegOp("COO2CSR")
    .fusion_type("OPAQUE")
    .input(0, "row_indices")
    .output(0, "output")
    .dtype_format(DataType.I64_Default, DataType.I64_Default)
    .dtype_format(DataType.I32_Default, DataType.I32_Default)
    .get_op_info()
)


@op_info_register(coo2csr_op_info)
def _coo2csr_akg():
    """COO2CSR AutoDiff register"""
    return
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""CSR2COO op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
|
||||
|
||||
# AKG GPU registration for CSR2COO: one tensor input (indptr) and one output, with matching
# int64 or int32 dtypes.
csr2coo_op_info = (
    AkgGpuRegOp("CSR2COO")
    .fusion_type("OPAQUE")
    .input(0, "indptr")
    .output(0, "output")
    .dtype_format(DataType.I64_Default, DataType.I64_Default)
    .dtype_format(DataType.I32_Default, DataType.I32_Default)
    .get_op_info()
)


@op_info_register(csr2coo_op_info)
def _csr2coo_akg():
    """CSR2COO AutoDiff register"""
    return
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""CSRGather op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
|
||||
|
||||
# AKG GPU registration for CSRGather: indptr/indices (int64 or int32) plus a float32 dense input,
# producing a float32 output.
csr_gather_op_info = (
    AkgGpuRegOp("CSRGather")
    .fusion_type("OPAQUE")
    .input(0, "indptr")
    .input(1, "indices")
    .input(2, "dense")
    .output(0, "output")
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default,
                  DataType.F32_Default)
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default,
                  DataType.F32_Default)
    .get_op_info()
)


@op_info_register(csr_gather_op_info)
def _csr_gather_akg():
    """CSRGather AutoDiff register"""
    return
|
|
@ -22,9 +22,6 @@ csr_mv_op_info = AkgGpuRegOp("CSRMV") \
|
|||
.input(2, "values") \
|
||||
.input(4, "dense_tensor") \
|
||||
.output(0, "output") \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \
|
||||
DataType.F32_Default, \
|
||||
DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
|
||||
DataType.F32_Default, \
|
||||
DataType.F32_Default) \
|
||||
|
|
|
@ -346,4 +346,19 @@ def _add_nonetensor_tensor(x, y):
|
|||
"""
|
||||
return x + y
|
||||
|
||||
|
||||
@_add_backward.register("CSRTensor", "CSRTensor")
def _add_csrtensor_csrtensor(x, y):
    """
    Adds CSRTensor and CSRTensor by summing their values.

    The result reuses the indptr, indices and shape of `x`.
    NOTE(review): this presumes `x` and `y` share the same sparsity pattern — confirm with callers.

    Args:
        x (CSRTensor): x
        y (CSRTensor): y

    Returns:
        CSRTensor.
    """
    summed_values = x.values + y.values
    return F.make_csr_tensor(x.indptr, x.indices, summed_values, x.shape)
|
||||
|
||||
hyper_add = base.HyperMap(_add_backward)
|
||||
|
|
|
@ -58,6 +58,12 @@ def _ones_like_coo_tensor(x):
|
|||
return F.make_coo_tensor(F.coo_tensor_get_indices(x), values, F.coo_tensor_get_dense_shape(x))
|
||||
|
||||
|
||||
@ones_like_leaf.register("CSRTensor")
def _ones_like_csr_tensor(x):
    """Returns a CSRTensor with the indptr, indices and shape of x whose values are all 1."""
    unit_values = ones_like(x.values)
    return F.make_csr_tensor(x.indptr, x.indices, unit_values, x.shape)
|
||||
|
||||
|
||||
@ones_like_leaf.register("Function")
|
||||
def _ones_like_func(x):
|
||||
"""
|
||||
|
|
|
@ -58,6 +58,20 @@ def _zeros_like_tensor(x):
|
|||
return F.zeros_like(x)
|
||||
|
||||
|
||||
@zeros_like_leaf.register("COOTensor")
def _zeros_like_coo_tensor(x):
    """Returns a COOTensor with the indices and shape of x whose values are all 0."""
    values = F.zeros_like(x.values)
    return F.make_coo_tensor(x.indices, values, x.shape)
|
||||
|
||||
|
||||
@zeros_like_leaf.register("CSRTensor")
|
||||
def _zeros_like_csr_tensor(x):
|
||||
"""Returns a CSRTensor with the same shape and dtype as x and all elements are 0."""
|
||||
values = F.zeros_like(x.values)
|
||||
return F.make_csr_tensor(x.indptr, x.indices, values, x.shape)
|
||||
|
||||
|
||||
@zeros_like_leaf.register("TypeType")
|
||||
def _zeros_like_type_type(x):
|
||||
"""Returns x because x is a type. This is usually used in the backprop process."""
|
||||
|
|
|
@ -152,6 +152,9 @@ stack = P.Stack()
|
|||
csr_mul = _csr_ops.CSRMul()
|
||||
csr_mv = _csr_ops.CSRMV()
|
||||
csr_reduce_sum = _csr_ops.CSRReduceSum()
|
||||
csr_gather = _csr_ops.CSRGather()
|
||||
csr2coo = _csr_ops.CSR2COO()
|
||||
coo2csr = _csr_ops.COO2CSR()
|
||||
|
||||
_select = P.Select()
|
||||
|
||||
|
@ -576,6 +579,7 @@ not_in_dict = Primitive("not_in_dict")
|
|||
mixed_precision_cast = Primitive("mixed_precision_cast")
|
||||
broadcast_gradient_args = Primitive('BroadcastGradientArgs')
|
||||
array_reduce = Primitive('array_reduce')
|
||||
zeros = P.Zeros()
|
||||
zeros_like = P.ZerosLike()
|
||||
distribute = Primitive('distribute')
|
||||
embed = Primitive('embed')
|
||||
|
@ -670,6 +674,11 @@ tensor_operator_registry.register('log', log)
|
|||
tensor_operator_registry.register('floor', floor)
|
||||
# support sparse tensor operators
|
||||
tensor_operator_registry.register('csr_mul', csr_mul)
|
||||
tensor_operator_registry.register('csr2coo', csr2coo)
|
||||
tensor_operator_registry.register('coo2csr', coo2csr)
|
||||
tensor_operator_registry.register('narrow', narrow)
|
||||
tensor_operator_registry.register('sort', sort)
|
||||
tensor_operator_registry.register('zeros', zeros)
|
||||
tensor_operator_registry.register('tensor_scatter_update', tensor_scatter_update)
|
||||
__all__ = [name for name in dir() if name[0] != "_"]
|
||||
__all__.remove('Primitive')
|
||||
|
|
|
@ -20,6 +20,9 @@ class CSRReduceSum(PrimitiveWithInfer):
|
|||
"""
|
||||
Reduces a dimension of a CSRTensor by summing all elements in the dimension.
|
||||
|
||||
.. warning::
|
||||
This is an experimental prototype that is subject to change and/or deletion.
|
||||
|
||||
Inputs:
|
||||
- **sparse_tensor** (CSRTensor) - A CSRTensor.
|
||||
- **axis** (int) - The dimensions to reduce.
|
||||
|
@ -64,6 +67,9 @@ class CSRMV(PrimitiveWithInfer):
|
|||
"""
|
||||
Sparse matrix-vector multiplication.
|
||||
|
||||
.. warning::
|
||||
This is an experimental prototype that is subject to change and/or deletion.
|
||||
|
||||
Inputs:
|
||||
- **sparse_tensor** (CSRTensor) - A CSRTensor.
|
||||
- **dense_tensor** (Tensor) - A dense Tensor.
|
||||
|
@ -109,6 +115,9 @@ class CSRMul(PrimitiveWithInfer):
|
|||
"""
|
||||
Elementwise multiplication on a CSRTensor and a dense tensor.
|
||||
|
||||
.. warning::
|
||||
This is an experimental prototype that is subject to change and/or deletion.
|
||||
|
||||
Note:
|
||||
The op outputs a 1-D dense tensor whose shape and values are the same as input `CSRTensor.values`.
|
||||
If a CSRTensor output is expected, please use `*` directly, e.g. `x * y`, where `x` or `y` can be a CSRTensor.
|
||||
|
@ -151,3 +160,129 @@ class CSRMul(PrimitiveWithInfer):
|
|||
"""Initialize CSRMul"""
|
||||
self.init_prim_io_names(inputs=['indptr', 'indices', 'values', 'dense_shape', 'dense_tensor'],
|
||||
outputs=['output'])
|
||||
|
||||
|
||||
class CSRGather(PrimitiveWithInfer):
|
||||
"""
|
||||
Returns the values of a CSRTensor indexed from a dense tensor using indptr and indices.
|
||||
|
||||
.. warning::
|
||||
This is an experimental prototype that is subject to change and/or deletion.
|
||||
|
||||
Inputs:
|
||||
- **indptr** (Tensor) - A Tensor.
|
||||
- **indices** (Tensor) - A Tensor.
|
||||
- **dense** (Tensor) - A Tensor.
|
||||
- **sparse_shape** (tuple) - A tuple of integers.
|
||||
|
||||
Outputs:
|
||||
Tensor, the dtype is the same as `dense`, the first dimension is the same shape as `indices` and the remaining
|
||||
dimensions are the same as ``dense.shape[2:]``.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore import Tensor, ops
|
||||
>>> from mindspore import dtype as mstype
|
||||
>>> class Net(nn.Cell):
|
||||
... def __init__(self):
|
||||
... super(Net, self).__init__()
|
||||
... self.op = ops.CSRGather()
|
||||
...
|
||||
... def construct(self, indptr, indices, dense, sparse_shape):
|
||||
... return self.op(indptr, indices, dense, sparse_shape)
|
||||
>>> indptr = Tensor([0, 1, 2])
|
||||
>>> indices = Tensor([0, 1])
|
||||
>>> sparse_shape = (2, 4)
|
||||
>>> dense = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
|
||||
>>> out = Net()(indptr, indices, dense, sparse_shape)
|
||||
>>> print(out)
|
||||
[1. 1.]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize CSRGather"""
|
||||
self.init_prim_io_names(inputs=['indptr', 'indices', 'dense', 'dense_shape'],
|
||||
outputs=['output'])
|
||||
|
||||
|
||||
class CSR2COO(PrimitiveWithInfer):
|
||||
"""
|
||||
Converts the indptr of a CSRTensor to the row indices of a COOTensor.
|
||||
|
||||
.. warning::
|
||||
This is an experimental prototype that is subject to change and/or deletion.
|
||||
|
||||
Inputs:
|
||||
- **indptr** (Tensor) - A Tensor.
|
||||
- **nnz** (int) - Denotes the number of non-zero elements in the sparse tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, the dtype is the same as `indptr` and has shape (`nnz`,).
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore import Tensor, ops
|
||||
>>> class Net(nn.Cell):
|
||||
... def __init__(self):
|
||||
... super(Net, self).__init__()
|
||||
... self.op = ops.CSR2COO()
|
||||
...
|
||||
... def construct(self, indptr, nnz):
|
||||
... return self.op(indptr, nnz)
|
||||
>>> indptr = Tensor([0, 1, 2])
|
||||
>>> out = Net()(indptr, 2)
|
||||
>>> print(out)
|
||||
[0 1]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize CSR2COO"""
|
||||
self.init_prim_io_names(inputs=['indptr', 'nnz'], outputs=['output'])
|
||||
|
||||
|
||||
class COO2CSR(PrimitiveWithInfer):
|
||||
"""
|
||||
Converts the row indices of a COOTensor to the indptr of a CSRTensor.
|
||||
|
||||
.. warning::
|
||||
This is an experimental prototype that is subject to change and/or deletion.
|
||||
|
||||
Inputs:
|
||||
- **row_indices** (Tensor) - A Tensor.
|
||||
- **height** (int) - The size of the first dimension (the number of rows) of the sparse tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, the dtype is the same as `row_indices` and has shape (`height` + 1,).
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore import Tensor, ops
|
||||
>>> from mindspore import dtype as mstype
|
||||
>>> class Net(nn.Cell):
|
||||
... def __init__(self):
|
||||
... super(Net, self).__init__()
|
||||
... self.op = ops.COO2CSR()
|
||||
...
|
||||
... def construct(self, row_indices, height):
|
||||
... return self.op(row_indices, height)
|
||||
>>> row_indices = Tensor([0, 1])
|
||||
>>> out = Net()(row_indices, 2)
|
||||
>>> print(out)
|
||||
[0 1 2]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize COO2CSR"""
|
||||
self.init_prim_io_names(inputs=['row_indices', 'height'], outputs=['output'])
|
||||
|
|
|
@ -91,3 +91,39 @@ def test_coo_tensor_in_while():
|
|||
assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0)
|
||||
assert np.allclose(out.values.asnumpy(), values.asnumpy(), .0, .0)
|
||||
assert out.shape == shape
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_coo_method():
|
||||
"""
|
||||
Feature: Test coo tensor methods.
|
||||
Description: Test coo_tensor.to_csr(), coo_tensor.to_dense().
|
||||
Expectation: Success.
|
||||
"""
|
||||
class COOToCSRNet(nn.Cell):
|
||||
def construct(self, coo_tensor):
|
||||
return coo_tensor.to_csr()
|
||||
|
||||
class COOToDenseNet(nn.Cell):
|
||||
def construct(self, coo_tensor):
|
||||
return coo_tensor.to_dense()
|
||||
|
||||
indices = Tensor([[1, 2], [0, 1]], dtype=mstype.int32)
|
||||
values = Tensor([2, 1], dtype=mstype.float32)
|
||||
shape = (3, 4)
|
||||
coo_tensor = COOTensor(indices, values, shape)
|
||||
|
||||
to_csr_output = COOToCSRNet()(coo_tensor)
|
||||
to_csr_expect_1 = np.array([0, 1, 2, 2], dtype=np.int32)
|
||||
to_csr_expect_2 = np.array([1, 2], dtype=np.int32)
|
||||
to_csr_expect_3 = np.array([1, 2], dtype=np.float32)
|
||||
assert np.allclose(to_csr_output.indptr.asnumpy(), to_csr_expect_1)
|
||||
assert np.allclose(to_csr_output.indices.asnumpy(), to_csr_expect_2)
|
||||
assert np.allclose(to_csr_output.values.asnumpy(), to_csr_expect_3)
|
||||
|
||||
to_dense_output = COOToDenseNet()(coo_tensor)
|
||||
to_dense_expect = np.array(
|
||||
[[0., 1., 0., 0.], [0., 0., 2., 0.], [0., 0., 0., 0.]], dtype=np.float32)
|
||||
assert np.allclose(to_dense_output.asnumpy(), to_dense_expect)
|
||||
|
|
|
@ -18,7 +18,7 @@ import os
|
|||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, CSRTensor, ms_function, nn, context
|
||||
from mindspore import Tensor, CSRTensor, ms_function, nn, context, ops
|
||||
from mindspore.ops.operations import _csr_ops
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.serialization import export, load
|
||||
|
@ -232,8 +232,8 @@ def test_csr_ops():
|
|||
csr_reducesum = _csr_ops.CSRReduceSum()
|
||||
csrmv = _csr_ops.CSRMV()
|
||||
|
||||
indptr = Tensor([0, 1, 2])
|
||||
indices = Tensor([0, 1])
|
||||
indptr = Tensor([0, 1, 2], dtype=mstype.int32)
|
||||
indices = Tensor([0, 1], dtype=mstype.int32)
|
||||
values = Tensor([2, 1], dtype=mstype.float32)
|
||||
dense_shape = (2, 4)
|
||||
|
||||
|
@ -331,8 +331,8 @@ def test_csrops_export_and_import_mindir():
|
|||
sparse2 = dence_tensor * csr_tensor
|
||||
return dense1, dense2, dense3, sparse1, sparse2
|
||||
|
||||
indptr = Tensor([0, 1, 2])
|
||||
indices = Tensor([0, 1])
|
||||
indptr = Tensor([0, 1, 2], dtype=mstype.int32)
|
||||
indices = Tensor([0, 1], dtype=mstype.int32)
|
||||
values = Tensor([2, 1], dtype=mstype.float32)
|
||||
shape = (2, 4)
|
||||
dense_tensor = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
|
||||
|
@ -428,3 +428,111 @@ def test_dtype_csr_tensor():
|
|||
out2 = graph_test()
|
||||
assert out1 in [mstype.float32]
|
||||
assert out2 in [mstype.float32]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_csr_bprop():
|
||||
"""
|
||||
Feature: Test back-propagation with CSR-related Ops.
|
||||
Description: Test the gradients of CSRReduceSum, CSRMul and CSRMV.
|
||||
Expectation: Success.
|
||||
"""
|
||||
class CSRMulNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(CSRMulNet, self).__init__()
|
||||
self.op = _csr_ops.CSRMul()
|
||||
|
||||
def construct(self, csr_tensor, dense):
|
||||
return self.op(csr_tensor, dense)
|
||||
|
||||
class CSRReduceSumNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(CSRReduceSumNet, self).__init__()
|
||||
self.op = _csr_ops.CSRReduceSum()
|
||||
|
||||
def construct(self, csr_tensor, axis):
|
||||
return self.op(csr_tensor, axis)
|
||||
|
||||
class CSRMVNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super(CSRMVNet, self).__init__()
|
||||
self.op = _csr_ops.CSRMV()
|
||||
|
||||
def construct(self, csr_tensor, dense):
|
||||
return self.op(csr_tensor, dense)
|
||||
|
||||
class BpropNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(BpropNet, self).__init__()
|
||||
self.net = net
|
||||
self.grad_op = ops.GradOperation(get_all=True)
|
||||
|
||||
def construct(self, *inputs):
|
||||
return self.grad_op(self.net)(*inputs)
|
||||
|
||||
indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
|
||||
indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
|
||||
values = Tensor(np.arange(6), dtype=mstype.float32)
|
||||
dense_shape = (3, 4)
|
||||
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
|
||||
|
||||
csr_mv_arg = Tensor([[1], [2], [3], [4]], dtype=mstype.float32)
|
||||
csr_mv_output_1, csr_mv_output_2 = BpropNet(CSRMVNet())(csr_tensor, csr_mv_arg)
|
||||
csr_mv_expect_1 = np.array([4, 1, 2, 3, 2, 4], dtype=np.float32)
|
||||
csr_mv_expect_2 = np.array([[1], [6], [3], [5]], dtype=np.float32)
|
||||
assert np.allclose(csr_mv_output_1.values.asnumpy(), csr_mv_expect_1)
|
||||
assert np.allclose(csr_mv_output_2.asnumpy(), csr_mv_expect_2)
|
||||
|
||||
csr_reduce_sum_output = BpropNet(CSRReduceSumNet())(csr_tensor, 1)
|
||||
csr_reduce_sum_expect = np.ones(6, dtype=np.float32)
|
||||
assert np.allclose(csr_reduce_sum_output[0].values.asnumpy(), csr_reduce_sum_expect)
|
||||
|
||||
csr_mul_arg_1 = Tensor([[1], [2], [3]], dtype=mstype.float32)
|
||||
csr_mul_output_1_1, csr_mul_output_1_2 = BpropNet(CSRMulNet())(csr_tensor, csr_mul_arg_1)
|
||||
csr_mul_expect_1_1 = np.array([1, 2, 2, 2, 3, 3], dtype=np.float32)
|
||||
csr_mul_expect_1_2 = np.array([[0], [6], [9]], dtype=np.float32)
|
||||
assert np.allclose(csr_mul_output_1_1.values.asnumpy(), csr_mul_expect_1_1)
|
||||
assert np.allclose(csr_mul_output_1_2.asnumpy(), csr_mul_expect_1_2)
|
||||
|
||||
csr_mul_arg_2 = Tensor(np.arange(12).reshape(3, 4), dtype=mstype.float32)
|
||||
csr_mul_output_2_1, csr_mul_output_2_2 = BpropNet(CSRMulNet())(csr_tensor, csr_mul_arg_2)
|
||||
csr_mul_expect_2_1 = np.array([3, 4, 5, 6, 9, 11], dtype=np.float32)
|
||||
csr_mul_expect_2_2 = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32)
|
||||
assert np.allclose(csr_mul_output_2_1.values.asnumpy(), csr_mul_expect_2_1)
|
||||
assert np.allclose(csr_mul_output_2_2.asnumpy(), csr_mul_expect_2_2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_csr_method():
|
||||
"""
|
||||
Feature: Test csr tensor methods.
|
||||
Description: Test csr_tensor.to_coo(), csr_tensor.to_dense().
|
||||
Expectation: Success.
|
||||
"""
|
||||
class CSRToCOONet(nn.Cell):
|
||||
def construct(self, csr_tensor):
|
||||
return csr_tensor.to_coo()
|
||||
|
||||
class CSRToDenseNet(nn.Cell):
|
||||
def construct(self, csr_tensor):
|
||||
return csr_tensor.to_dense()
|
||||
|
||||
indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
|
||||
indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
|
||||
values = Tensor(np.arange(6), dtype=mstype.float32)
|
||||
dense_shape = (3, 4)
|
||||
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
|
||||
|
||||
to_coo_output = CSRToCOONet()(csr_tensor)
|
||||
to_coo_expect_1 = np.array([[0, 3], [1, 0], [1, 1], [1, 2], [2, 1], [2, 3]], dtype=np.int32)
|
||||
to_coo_expect_2 = np.arange(6).astype(np.float32)
|
||||
assert np.allclose(to_coo_output.indices.asnumpy(), to_coo_expect_1)
|
||||
assert np.allclose(to_coo_output.values.asnumpy(), to_coo_expect_2)
|
||||
|
||||
to_dense_output = CSRToDenseNet()(csr_tensor)
|
||||
to_dense_expect = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32)
|
||||
assert np.allclose(to_dense_output.asnumpy(), to_dense_expect)
|
||||
|
|
Loading…
Reference in New Issue