add csr bprop && csr method
This commit is contained in:
parent
3dea54b28d
commit
080ad981d6
|
@ -58,6 +58,9 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
|
||||||
Register(prim::kPrimUnsortedSegmentMin->name(), {2});
|
Register(prim::kPrimUnsortedSegmentMin->name(), {2});
|
||||||
Register(prim::kPrimUnsortedSegmentMax->name(), {2});
|
Register(prim::kPrimUnsortedSegmentMax->name(), {2});
|
||||||
Register(prim::kPrimCSRReduceSum->name(), {1});
|
Register(prim::kPrimCSRReduceSum->name(), {1});
|
||||||
|
Register(prim::kPrimCSRGather->name(), {3});
|
||||||
|
Register(prim::kPrimCSR2COO->name(), {1});
|
||||||
|
Register(prim::kPrimCOO2CSR->name(), {1});
|
||||||
Register(kSparseGatherV2OpName, {2});
|
Register(kSparseGatherV2OpName, {2});
|
||||||
Register(kUnsortedSegmentProdOpName, {2});
|
Register(kUnsortedSegmentProdOpName, {2});
|
||||||
Register(kSimpleMeanGradOpName, {1});
|
Register(kSimpleMeanGradOpName, {1});
|
||||||
|
|
|
@ -191,7 +191,6 @@ const AnfNodePtr SparseProcess::Process(const FuncGraphPtr &func_graph, const An
|
||||||
}
|
}
|
||||||
auto new_node = cnode->func_graph()->NewCNode(new_inputs);
|
auto new_node = cnode->func_graph()->NewCNode(new_inputs);
|
||||||
new_node->set_abstract(node->abstract());
|
new_node->set_abstract(node->abstract());
|
||||||
AnfAlgo::SetNodeAttr("is_csr", MakeValue(true), new_node);
|
|
||||||
return new_node;
|
return new_node;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,8 @@ using MetaTensor = mindspore::tensor::MetaTensor;
|
||||||
using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
|
using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
|
||||||
using CSRTensor = mindspore::tensor::CSRTensor;
|
using CSRTensor = mindspore::tensor::CSRTensor;
|
||||||
using CSRTensorPtr = mindspore::tensor::CSRTensorPtr;
|
using CSRTensorPtr = mindspore::tensor::CSRTensorPtr;
|
||||||
|
using COOTensor = mindspore::tensor::COOTensor;
|
||||||
|
using COOTensorPtr = mindspore::tensor::COOTensorPtr;
|
||||||
|
|
||||||
using InstanceCheckFunc = std::function<bool(const py::object &)>;
|
using InstanceCheckFunc = std::function<bool(const py::object &)>;
|
||||||
using InstanceConvertFunc = std::function<ValuePtr(const py::object &, bool, const TypePtr &)>;
|
using InstanceConvertFunc = std::function<ValuePtr(const py::object &, bool, const TypePtr &)>;
|
||||||
|
@ -489,6 +491,7 @@ std::vector<DataConverterPtr> GetDataConverters() {
|
||||||
std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
|
std::make_shared<ByTypeDataConverter<Tensor>>(ObjCast<TensorPtr>),
|
||||||
std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
|
std::make_shared<ByTypeDataConverter<MetaTensor>>(ObjCast<MetaTensorPtr>),
|
||||||
std::make_shared<ByTypeDataConverter<CSRTensor>>(ObjCast<CSRTensorPtr>),
|
std::make_shared<ByTypeDataConverter<CSRTensor>>(ObjCast<CSRTensorPtr>),
|
||||||
|
std::make_shared<ByTypeDataConverter<COOTensor>>(ObjCast<COOTensorPtr>),
|
||||||
std::make_shared<ByTypeDataConverter<py::tuple>>(ConvertTuple),
|
std::make_shared<ByTypeDataConverter<py::tuple>>(ConvertTuple),
|
||||||
std::make_shared<ByTypeDataConverter<py::list>>(ConvertList),
|
std::make_shared<ByTypeDataConverter<py::list>>(ConvertList),
|
||||||
std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>),
|
std::make_shared<ByTypeDataConverter<py::bool_>>(PyCast<BoolImm, bool>),
|
||||||
|
|
|
@ -98,6 +98,7 @@ namespace pipeline {
|
||||||
using Tensor = mindspore::tensor::Tensor;
|
using Tensor = mindspore::tensor::Tensor;
|
||||||
using MetaTensor = mindspore::tensor::MetaTensor;
|
using MetaTensor = mindspore::tensor::MetaTensor;
|
||||||
using CSRTensor = mindspore::tensor::CSRTensor;
|
using CSRTensor = mindspore::tensor::CSRTensor;
|
||||||
|
using COOTensor = mindspore::tensor::COOTensor;
|
||||||
using TensorOrderMap = std::map<std::string, std::shared_ptr<Tensor>>;
|
using TensorOrderMap = std::map<std::string, std::shared_ptr<Tensor>>;
|
||||||
using mindspore::abstract::AbstractTensor;
|
using mindspore::abstract::AbstractTensor;
|
||||||
using mindspore::abstract::AbstractTensorPtr;
|
using mindspore::abstract::AbstractTensorPtr;
|
||||||
|
@ -178,7 +179,8 @@ bool CheckArgValid(const py::handle &arg) {
|
||||||
|
|
||||||
return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<py::none>(arg) ||
|
return py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<py::none>(arg) ||
|
||||||
py::isinstance<Number>(arg) ||
|
py::isinstance<Number>(arg) ||
|
||||||
((py::isinstance<Tensor>(arg) || py::isinstance<CSRTensor>(arg)) && !py::hasattr(arg, "__parameter__"));
|
((py::isinstance<Tensor>(arg) || py::isinstance<CSRTensor>(arg) || py::isinstance<COOTensor>(arg)) &&
|
||||||
|
!py::hasattr(arg, "__parameter__"));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string GetCompileExceptionInfo() {
|
std::string GetCompileExceptionInfo() {
|
||||||
|
|
|
@ -216,7 +216,17 @@ BuiltInTypeMap &GetMethodMap() {
|
||||||
}},
|
}},
|
||||||
{kObjectTypeJTagged, {}},
|
{kObjectTypeJTagged, {}},
|
||||||
{kObjectTypeSymbolicKeyType, {}},
|
{kObjectTypeSymbolicKeyType, {}},
|
||||||
{kObjectTypeEnvType, {}}};
|
{kObjectTypeEnvType, {}},
|
||||||
|
{kObjectTypeCOOTensorType,
|
||||||
|
{
|
||||||
|
{"to_csr", std::string("coo_to_csr")},
|
||||||
|
{"to_dense", std::string("coo_to_dense")},
|
||||||
|
}},
|
||||||
|
{kObjectTypeCSRTensorType,
|
||||||
|
{
|
||||||
|
{"to_coo", std::string("csr_to_coo")},
|
||||||
|
{"to_dense", std::string("csr_to_dense")},
|
||||||
|
}}};
|
||||||
return method_map;
|
return method_map;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -320,7 +320,8 @@ size_t CountValueNum(const ValueTuplePtr &value_tuple) {
|
||||||
|
|
||||||
bool IsCustomCSROP(const AnfNodePtr &cnode) {
|
bool IsCustomCSROP(const AnfNodePtr &cnode) {
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul, prim::kPrimCSRMV};
|
const PrimitiveSet prims{prim::kPrimCSRReduceSum, prim::kPrimCSRMul, prim::kPrimCSRMV,
|
||||||
|
prim::kPrimCSRGather, prim::kPrimCSR2COO, prim::kPrimCOO2CSR};
|
||||||
return IsOneOfPrimitiveCNode(cnode, prims);
|
return IsOneOfPrimitiveCNode(cnode, prims);
|
||||||
}
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -94,8 +94,13 @@ const mindspore::HashSet<std::string> make_sparse_set = {{prim::kMakeCSRTensor},
|
||||||
// sparse_op_set records all sparse_compute operators, which takes sparsetensor
|
// sparse_op_set records all sparse_compute operators, which takes sparsetensor
|
||||||
// and (possibly) dense tensors, used in backend common optimization pass:
|
// and (possibly) dense tensors, used in backend common optimization pass:
|
||||||
// sparse_process.cc
|
// sparse_process.cc
|
||||||
const mindspore::HashSet<std::string> sparse_op_set = {
|
const mindspore::HashSet<std::string> sparse_op_set = {{prim::kSparseTensorDenseMatmul},
|
||||||
{prim::kSparseTensorDenseMatmul}, {prim::kCSRDenseMul}, {prim::kCSRReduceSum}, {prim::kCSRMV}, {prim::kCSRMul}};
|
{prim::kCSRDenseMul},
|
||||||
|
{prim::kCSRReduceSum},
|
||||||
|
{prim::kCSRMV},
|
||||||
|
{prim::kCSRMul},
|
||||||
|
{prim::kCSRGather},
|
||||||
|
{prim::kCSR2COO}};
|
||||||
|
|
||||||
bool IsCustomCSROP(const AnfNodePtr &cnode);
|
bool IsCustomCSROP(const AnfNodePtr &cnode);
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -165,6 +165,12 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
|
AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list);
|
||||||
|
AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list);
|
||||||
|
AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplIsCSRFunc(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
|
|
@ -36,6 +36,7 @@ namespace abstract {
|
||||||
constexpr auto kCSRDenseShape = "dense_shape";
|
constexpr auto kCSRDenseShape = "dense_shape";
|
||||||
constexpr auto kCSRAxis = "axis";
|
constexpr auto kCSRAxis = "axis";
|
||||||
constexpr auto kCSRAvgRows = "csr_avg_rows";
|
constexpr auto kCSRAvgRows = "csr_avg_rows";
|
||||||
|
constexpr auto kIsCSR = "is_csr";
|
||||||
AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list) {
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
// An object of a subclass of AbstractBase
|
// An object of a subclass of AbstractBase
|
||||||
|
@ -439,6 +440,9 @@ AbstractBasePtr InferImplCSRMul(const AnalysisEnginePtr &, const PrimitivePtr &p
|
||||||
<< "but sparse tensor has " << sparse_shape.size() << " dimensions, "
|
<< "but sparse tensor has " << sparse_shape.size() << " dimensions, "
|
||||||
<< "and dense tensor has " << dense_shape.size() << " dimensions, ";
|
<< "and dense tensor has " << dense_shape.size() << " dimensions, ";
|
||||||
}
|
}
|
||||||
|
if (dense_shape[0] != sparse_shape[0]) {
|
||||||
|
MS_EXCEPTION(ValueError) << "Currently, only support dense tensor broadcast with last dim!";
|
||||||
|
}
|
||||||
auto ret = sparse->values()->Broaden();
|
auto ret = sparse->values()->Broaden();
|
||||||
|
|
||||||
MS_EXCEPTION_IF_NULL(sparse->indices()->shape());
|
MS_EXCEPTION_IF_NULL(sparse->indices()->shape());
|
||||||
|
@ -446,7 +450,7 @@ AbstractBasePtr InferImplCSRMul(const AnalysisEnginePtr &, const PrimitivePtr &p
|
||||||
int csr_avg_rows = SizeToInt(nnz_vec[0] / dense_shape[0]);
|
int csr_avg_rows = SizeToInt(nnz_vec[0] / dense_shape[0]);
|
||||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||||
primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape));
|
primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape));
|
||||||
|
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -482,7 +486,7 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr
|
||||||
int csr_avg_rows = SizeToInt(nnz_vec[0] / dense_shape[0]);
|
int csr_avg_rows = SizeToInt(nnz_vec[0] / dense_shape[0]);
|
||||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||||
primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape));
|
primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape));
|
||||||
|
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -532,7 +536,98 @@ AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const Primitive
|
||||||
int csr_avg_rows = SizeToInt(nnz_vec[0] / sparse_shape[0]);
|
int csr_avg_rows = SizeToInt(nnz_vec[0] / sparse_shape[0]);
|
||||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||||
primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape));
|
primitive->set_attr(kCSRDenseShape, MakeValue(sparse_shape));
|
||||||
|
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: the indptr and indices of a sparse csr tensor, a dense tensor, and the shape of the sparse tensor.
|
||||||
|
constexpr auto kCSRShapeSize = 2;
|
||||||
|
constexpr auto kCSRArgsSize = 4;
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);
|
||||||
|
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||||
|
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||||
|
auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
|
||||||
|
auto sparse_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 3);
|
||||||
|
MS_EXCEPTION_IF_NULL(indptr);
|
||||||
|
MS_EXCEPTION_IF_NULL(indices);
|
||||||
|
MS_EXCEPTION_IF_NULL(dense);
|
||||||
|
MS_EXCEPTION_IF_NULL(sparse_shape);
|
||||||
|
|
||||||
|
if (sparse_shape->size() != kCSRShapeSize) {
|
||||||
|
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRShapeSize << "-D inputs!"
|
||||||
|
<< "But sparse tensor has " << sparse_shape->size() << " dimensions.";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto shape_value = sparse_shape->BuildValue()->cast<ValueTuplePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(shape_value);
|
||||||
|
auto nnz_vec = indices->shape()->shape();
|
||||||
|
int64_t csr_avg_rows = nnz_vec[0] / GetValue<int64_t>(shape_value->value()[0]);
|
||||||
|
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||||
|
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||||
|
|
||||||
|
MS_EXCEPTION_IF_NULL(indices->shape());
|
||||||
|
ShapeVector out_shape = indices->shape()->shape();
|
||||||
|
MS_EXCEPTION_IF_NULL(dense->element());
|
||||||
|
auto ret = std::make_shared<AbstractTensor>(dense->element()->BuildType(), out_shape);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: the indptr of a sparse csr tensor, and the number of non-zero elements.
|
||||||
|
constexpr auto kCSRArgsSize = 2;
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);
|
||||||
|
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||||
|
auto nnz = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||||
|
MS_EXCEPTION_IF_NULL(indptr);
|
||||||
|
MS_EXCEPTION_IF_NULL(nnz);
|
||||||
|
|
||||||
|
MS_EXCEPTION_IF_NULL(nnz->BuildValue());
|
||||||
|
ShapeVector out_shape;
|
||||||
|
if (nnz->BuildValue()->isa<Int32Imm>() || nnz->BuildValue()->isa<Int64Imm>()) {
|
||||||
|
int64_t nnz_value = GetValue<int64_t>(nnz->BuildValue());
|
||||||
|
out_shape.push_back(nnz_value);
|
||||||
|
} else {
|
||||||
|
MS_EXCEPTION(ValueError) << "Currently, only support Integer nnz.";
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_EXCEPTION_IF_NULL(indptr->shape());
|
||||||
|
auto num_rows = indptr->shape()->shape()[0] - 1;
|
||||||
|
int csr_avg_rows = GetValue<int64_t>(nnz->BuildValue()) / num_rows;
|
||||||
|
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||||
|
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||||
|
|
||||||
|
MS_EXCEPTION_IF_NULL(indptr->element());
|
||||||
|
auto ret = std::make_shared<AbstractTensor>(indptr->element()->BuildType(), out_shape);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: the row indices of a sparse coo tensor, and the size of its first dimension.
|
||||||
|
constexpr auto kCSRArgsSize = 2;
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);
|
||||||
|
auto row_indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||||
|
auto height = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||||
|
MS_EXCEPTION_IF_NULL(row_indices);
|
||||||
|
MS_EXCEPTION_IF_NULL(height);
|
||||||
|
|
||||||
|
MS_EXCEPTION_IF_NULL(height->BuildValue());
|
||||||
|
ShapeVector out_shape;
|
||||||
|
if (height->BuildValue()->isa<Int32Imm>() || height->BuildValue()->isa<Int64Imm>()) {
|
||||||
|
int64_t height_value = GetValue<int64_t>(height->BuildValue());
|
||||||
|
out_shape.push_back(height_value + 1);
|
||||||
|
} else {
|
||||||
|
MS_EXCEPTION(ValueError) << "Currently, only support Integer height.";
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_EXCEPTION_IF_NULL(row_indices->element());
|
||||||
|
auto ret = std::make_shared<AbstractTensor>(row_indices->element()->BuildType(), out_shape);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -236,6 +236,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
||||||
{prim::kPrimCSRMul, R{InferImplCSRMul, nullptr, true}},
|
{prim::kPrimCSRMul, R{InferImplCSRMul, nullptr, true}},
|
||||||
{prim::kPrimCSRMV, R{InferImplCSRMV, nullptr, true}},
|
{prim::kPrimCSRMV, R{InferImplCSRMV, nullptr, true}},
|
||||||
{prim::kPrimCSRReduceSum, R{InferImplCSRReduceSum, nullptr, true}},
|
{prim::kPrimCSRReduceSum, R{InferImplCSRReduceSum, nullptr, true}},
|
||||||
|
{prim::kPrimCSRGather, R{InferImplCSRGather, nullptr, true}},
|
||||||
|
{prim::kPrimCSR2COO, R{InferImplCSR2COO, nullptr, true}},
|
||||||
|
{prim::kPrimCOO2CSR, R{InferImplCOO2CSR, nullptr, true}},
|
||||||
// Comm Ops
|
// Comm Ops
|
||||||
{prim::kPrimAllSwap, R{InferImplAllSwap, nullptr, true}},
|
{prim::kPrimAllSwap, R{InferImplAllSwap, nullptr, true}},
|
||||||
{prim::kPrimMemCpyAsync, R{InferImplMemCpyAsync, nullptr, true}},
|
{prim::kPrimMemCpyAsync, R{InferImplMemCpyAsync, nullptr, true}},
|
||||||
|
|
|
@ -164,6 +164,9 @@ constexpr auto kCSRDenseMul = "CSRDenseMul";
|
||||||
constexpr auto kCSRReduceSum = "CSRReduceSum";
|
constexpr auto kCSRReduceSum = "CSRReduceSum";
|
||||||
constexpr auto kCSRMV = "CSRMV";
|
constexpr auto kCSRMV = "CSRMV";
|
||||||
constexpr auto kCSRMul = "CSRMul";
|
constexpr auto kCSRMul = "CSRMul";
|
||||||
|
constexpr auto kCSRGather = "CSRGather";
|
||||||
|
constexpr auto kCSR2COO = "CSR2COO";
|
||||||
|
constexpr auto kCOO2CSR = "COO2CSR";
|
||||||
|
|
||||||
// Meta Function Graph
|
// Meta Function Graph
|
||||||
constexpr auto kJ = "J";
|
constexpr auto kJ = "J";
|
||||||
|
@ -606,6 +609,9 @@ GVAR_DEF(PrimitivePtr, kPrimCSRDenseMul, std::make_shared<Primitive>(kCSRDenseMu
|
||||||
GVAR_DEF(PrimitivePtr, kPrimCSRReduceSum, std::make_shared<Primitive>(kCSRReduceSum));
|
GVAR_DEF(PrimitivePtr, kPrimCSRReduceSum, std::make_shared<Primitive>(kCSRReduceSum));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimCSRMV, std::make_shared<Primitive>(kCSRMV));
|
GVAR_DEF(PrimitivePtr, kPrimCSRMV, std::make_shared<Primitive>(kCSRMV));
|
||||||
GVAR_DEF(PrimitivePtr, kPrimCSRMul, std::make_shared<Primitive>(kCSRMul));
|
GVAR_DEF(PrimitivePtr, kPrimCSRMul, std::make_shared<Primitive>(kCSRMul));
|
||||||
|
GVAR_DEF(PrimitivePtr, kPrimCSRGather, std::make_shared<Primitive>(kCSRGather));
|
||||||
|
GVAR_DEF(PrimitivePtr, kPrimCSR2COO, std::make_shared<Primitive>(kCSR2COO));
|
||||||
|
GVAR_DEF(PrimitivePtr, kPrimCOO2CSR, std::make_shared<Primitive>(kCOO2CSR));
|
||||||
|
|
||||||
// TensorList
|
// TensorList
|
||||||
GVAR_DEF(PrimitivePtr, kPrimTensorListFromTensor, std::make_shared<Primitive>("TensorListFromTensor"));
|
GVAR_DEF(PrimitivePtr, kPrimTensorListFromTensor, std::make_shared<Primitive>("TensorListFromTensor"));
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from mindspore import Tensor, Parameter
|
from mindspore import Tensor, Parameter, CSRTensor, COOTensor
|
||||||
from mindspore import dtype as mstype
|
from mindspore import dtype as mstype
|
||||||
|
|
||||||
from ..._checkparam import Validator as validator
|
from ..._checkparam import Validator as validator
|
||||||
|
@ -1573,6 +1573,32 @@ def while_cond(x):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def coo_to_csr(x):
|
||||||
|
row_indices = x.indices[:, 0]
|
||||||
|
col_indices = x.indices[:, 1]
|
||||||
|
idx_dtype = x.indices.dtype
|
||||||
|
row_indices, sort_idx = F.sort(row_indices.astype(mstype.float32))
|
||||||
|
row_indices = row_indices.astype(idx_dtype)
|
||||||
|
col_indices = col_indices[sort_idx]
|
||||||
|
values = x.values[sort_idx]
|
||||||
|
indptr = F.coo2csr(row_indices, x.shape[0])
|
||||||
|
return CSRTensor(indptr, col_indices, values, x.shape)
|
||||||
|
|
||||||
|
|
||||||
|
def coo_to_dense(x):
|
||||||
|
zeros_tensor = F.zeros(x.shape, x.values.dtype)
|
||||||
|
return F.tensor_scatter_update(zeros_tensor, x.indices, x.values)
|
||||||
|
|
||||||
|
def csr_to_coo(x):
|
||||||
|
row_indices = F.csr2coo(x.indptr, x.values.shape[0])
|
||||||
|
coo_indices = P.Stack(1)((row_indices, x.indices))
|
||||||
|
return COOTensor(coo_indices, x.values, x.shape)
|
||||||
|
|
||||||
|
def csr_to_dense(x):
|
||||||
|
coo_tensor = x.to_coo()
|
||||||
|
return coo_tensor.to_dense()
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def empty_tensor(dtype):
|
def empty_tensor(dtype):
|
||||||
return Tensor([], dtype)
|
return Tensor([], dtype)
|
||||||
|
|
|
@ -290,7 +290,7 @@ class _MindsporeFunctionExecutor:
|
||||||
return None
|
return None
|
||||||
new_inputs = []
|
new_inputs = []
|
||||||
for i in args_list:
|
for i in args_list:
|
||||||
if isinstance(i, (Tensor, CSRTensor)):
|
if isinstance(i, (Tensor, CSRTensor, COOTensor)):
|
||||||
new_inputs.append(i)
|
new_inputs.append(i)
|
||||||
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
|
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
|
||||||
new_inputs.append(i)
|
new_inputs.append(i)
|
||||||
|
|
|
@ -2468,6 +2468,24 @@ class COOTensor(COOTensor_):
|
||||||
def shape(self):
|
def shape(self):
|
||||||
return self._shape
|
return self._shape
|
||||||
|
|
||||||
|
def to_csr(self):
|
||||||
|
"Converts COOTensor to CSRTensor."
|
||||||
|
row_indices = self.indices[:, 0]
|
||||||
|
col_indices = self.indices[:, 1]
|
||||||
|
idx_dtype = self.indices.dtype
|
||||||
|
row_indices, sort_idx = tensor_operator_registry.get("sort")(
|
||||||
|
row_indices.astype(mstype.float32))
|
||||||
|
row_indices = row_indices.astype(idx_dtype)
|
||||||
|
col_indices = col_indices[sort_idx]
|
||||||
|
values = self.values[sort_idx]
|
||||||
|
indptr = tensor_operator_registry.get("coo2csr")(row_indices, self.shape[0])
|
||||||
|
return CSRTensor(indptr, col_indices, values, self.shape)
|
||||||
|
|
||||||
|
def to_dense(self):
|
||||||
|
zeros_tensor = tensor_operator_registry.get("zeros")(self.shape, self.values.dtype)
|
||||||
|
return tensor_operator_registry.get("tensor_scatter_update")(
|
||||||
|
zeros_tensor, self.indices, self.values)
|
||||||
|
|
||||||
|
|
||||||
class CSRTensor(CSRTensor_):
|
class CSRTensor(CSRTensor_):
|
||||||
"""
|
"""
|
||||||
|
@ -2566,6 +2584,15 @@ class CSRTensor(CSRTensor_):
|
||||||
def to_tuple(self):
|
def to_tuple(self):
|
||||||
return self.indptr, self.indices, self.values, self.shape
|
return self.indptr, self.indices, self.values, self.shape
|
||||||
|
|
||||||
|
def to_coo(self):
|
||||||
|
row_indices = tensor_operator_registry.get("csr2coo")(self.indptr, self.values.shape[0])
|
||||||
|
coo_indices = tensor_operator_registry.get("stack")(1)((row_indices, self.indices))
|
||||||
|
return COOTensor(coo_indices, self.values, self.shape)
|
||||||
|
|
||||||
|
def to_dense(self):
|
||||||
|
coo_tensor = self.to_coo()
|
||||||
|
return coo_tensor.to_dense()
|
||||||
|
|
||||||
|
|
||||||
def _vm_compare(*args):
|
def _vm_compare(*args):
|
||||||
"""Implement `vm_compare` for tensor."""
|
"""Implement `vm_compare` for tensor."""
|
||||||
|
|
|
@ -32,7 +32,7 @@ from .._checkparam import Validator
|
||||||
from ..common import dtype as mstype
|
from ..common import dtype as mstype
|
||||||
from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
|
from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
|
||||||
from ..common.parameter import Parameter, ParameterTuple
|
from ..common.parameter import Parameter, ParameterTuple
|
||||||
from ..common.tensor import Tensor, CSRTensor
|
from ..common.tensor import Tensor, CSRTensor, COOTensor
|
||||||
from ..ops.operations import Cast
|
from ..ops.operations import Cast
|
||||||
from ..ops.primitive import Primitive
|
from ..ops.primitive import Primitive
|
||||||
from ..ops.operations import _inner_ops as inner
|
from ..ops.operations import _inner_ops as inner
|
||||||
|
@ -815,6 +815,8 @@ class Cell(Cell_):
|
||||||
if i.has_init:
|
if i.has_init:
|
||||||
i.init_data()
|
i.init_data()
|
||||||
new_inputs.append(i)
|
new_inputs.append(i)
|
||||||
|
elif isinstance(i, COOTensor):
|
||||||
|
new_inputs.append(i)
|
||||||
elif isinstance(i, CSRTensor):
|
elif isinstance(i, CSRTensor):
|
||||||
new_inputs.append(i)
|
new_inputs.append(i)
|
||||||
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
|
elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
|
||||||
|
@ -1267,16 +1269,16 @@ class Cell(Cell_):
|
||||||
|
|
||||||
def _add_mixed_precision_flag(self, **flags):
|
def _add_mixed_precision_flag(self, **flags):
|
||||||
"""Add mixed precision flag to current cell"""
|
"""Add mixed precision flag to current cell"""
|
||||||
if "fp16" in flags and flags["fp16"]:
|
if "fp16" in flags and flags.get("fp16", False):
|
||||||
Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16)
|
Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP16)
|
||||||
if "fp32" in flags and flags["fp32"]:
|
if "fp32" in flags and flags.get("fp32", False):
|
||||||
Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32)
|
Cell_.set_mixed_precision_type(self, MixedPrecisionType.FP32)
|
||||||
|
|
||||||
def _add_mixed_precision_flag_recursive(self, **flags):
|
def _add_mixed_precision_flag_recursive(self, **flags):
|
||||||
"""Add mixed precision flag to each cell"""
|
"""Add mixed precision flag to each cell"""
|
||||||
if "fp16" in flags and flags["fp16"]:
|
if "fp16" in flags and flags.get("fp16", False):
|
||||||
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16)
|
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP16)
|
||||||
if "fp32" in flags and flags["fp32"]:
|
if "fp32" in flags and flags.get("fp32", False):
|
||||||
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
|
self._set_mixed_precision_type_recursive(MixedPrecisionType.FP32)
|
||||||
|
|
||||||
def add_flags(self, **flags):
|
def add_flags(self, **flags):
|
||||||
|
@ -1876,15 +1878,16 @@ class Cell(Cell_):
|
||||||
"""
|
"""
|
||||||
self._recompute()
|
self._recompute()
|
||||||
if 'mp_comm_recompute' in kwargs.keys():
|
if 'mp_comm_recompute' in kwargs.keys():
|
||||||
self._mp_comm_recompute(kwargs['mp_comm_recompute'])
|
self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
|
||||||
if 'parallel_optimizer_comm_recompute' in kwargs.keys():
|
if 'parallel_optimizer_comm_recompute' in kwargs.keys():
|
||||||
if kwargs['parallel_optimizer_comm_recompute'] and context.get_auto_parallel_context("pipeline_stages") > 1:
|
if (kwargs.get('parallel_optimizer_comm_recompute', False) and
|
||||||
|
context.get_auto_parallel_context("pipeline_stages") > 1):
|
||||||
logger.warning("Currently, the communication operator allgathers introduced by optimizer shard "
|
logger.warning("Currently, the communication operator allgathers introduced by optimizer shard "
|
||||||
"are not support recomputation in pipeline parallel.")
|
"are not support recomputation in pipeline parallel.")
|
||||||
elif context.get_auto_parallel_context("pipeline_stages") == 1:
|
elif context.get_auto_parallel_context("pipeline_stages") == 1:
|
||||||
self._parallel_optimizer_comm_recompute(kwargs['parallel_optimizer_comm_recompute'])
|
self._parallel_optimizer_comm_recompute(kwargs.get('parallel_optimizer_comm_recompute', False))
|
||||||
if 'recompute_slice_activation' in kwargs.keys():
|
if 'recompute_slice_activation' in kwargs.keys():
|
||||||
self._recompute_slice_activation(kwargs['recompute_slice_activation'])
|
self._recompute_slice_activation(kwargs.get('recompute_slice_activation', False))
|
||||||
|
|
||||||
for key, _ in kwargs.items():
|
for key, _ in kwargs.items():
|
||||||
if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute', 'recompute_slice_activation'):
|
if key not in ('mp_comm_recompute', 'parallel_optimizer_comm_recompute', 'recompute_slice_activation'):
|
||||||
|
|
|
@ -14,8 +14,10 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""bprop primitives"""
|
"""bprop primitives"""
|
||||||
|
from ...common import dtype as mstype
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from .. import operations as P
|
from .. import operations as P
|
||||||
|
from ..operations import _csr_ops
|
||||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||||
from .grad_base import bprops, bprop_getters
|
from .grad_base import bprops, bprop_getters
|
||||||
|
|
||||||
|
@ -78,3 +80,80 @@ def get_bprop_sparse_tensor_dense_matmul(self):
|
||||||
values_grad = F.reduce_sum(parts_a * parts_b, 1)
|
values_grad = F.reduce_sum(parts_a * parts_b, 1)
|
||||||
return zeros_like(indices), values_grad, zeros_like(dense_shape), dense_grad
|
return zeros_like(indices), values_grad, zeros_like(dense_shape), dense_grad
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
@bprop_getters.register(_csr_ops.CSRReduceSum)
|
||||||
|
def get_bprop_csr_reduce_sum(self):
|
||||||
|
"Back-propagation for CSRReduceSum."
|
||||||
|
def bprop(csr_tensor, axis, out, dout):
|
||||||
|
indptr = csr_tensor.indptr
|
||||||
|
indices = csr_tensor.indices
|
||||||
|
shape = csr_tensor.shape
|
||||||
|
|
||||||
|
output_shape_kept_dims = F.reduced_shape(shape, axis)
|
||||||
|
tile_scaling = F.tuple_div(shape, output_shape_kept_dims)
|
||||||
|
values_grad_dense = F.tile(F.reshape(dout, output_shape_kept_dims), tile_scaling)
|
||||||
|
values_grad = F.csr_gather(indptr, indices, values_grad_dense, shape)
|
||||||
|
return F.make_csr_tensor(indptr, indices, values_grad, shape), zeros_like(axis)
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
@bprop_getters.register(_csr_ops.CSRMV)
|
||||||
|
def get_bprop_csr_mv(self):
|
||||||
|
"Back-propagation for CSRMV."
|
||||||
|
def bprop(csr_tensor, dense, out, dout):
|
||||||
|
indptr = F.csr_tensor_get_indptr(csr_tensor)
|
||||||
|
indices = F.csr_tensor_get_indices(csr_tensor)
|
||||||
|
values = F.csr_tensor_get_values(csr_tensor)
|
||||||
|
dense_shape = csr_tensor.shape
|
||||||
|
|
||||||
|
rows = F.csr2coo(indptr, indices.shape[0])
|
||||||
|
idx_dtype = rows.dtype
|
||||||
|
rows_transposed, cols_indexing = F.sort(indices.astype(mstype.float32))
|
||||||
|
rows_transposed = rows_transposed.astype(idx_dtype)
|
||||||
|
cols_transposed = rows[cols_indexing]
|
||||||
|
values_transposed = values[cols_indexing]
|
||||||
|
indptr_transposed = F.coo2csr(rows_transposed, dense_shape[1])
|
||||||
|
csr_tensor_transposed = F.make_csr_tensor(
|
||||||
|
indptr_transposed, cols_transposed, values_transposed, (dense_shape[1], dense_shape[0]))
|
||||||
|
|
||||||
|
dense_grad = F.csr_mv(csr_tensor_transposed, dout)
|
||||||
|
parts_a = F.gather(dout, rows, 0)
|
||||||
|
parts_b = F.gather(dense, indices, 0)
|
||||||
|
values_grad = F.reduce_sum(parts_a * parts_b, 1)
|
||||||
|
return F.make_csr_tensor(indptr, indices, values_grad, csr_tensor.shape), dense_grad
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
@bprop_getters.register(_csr_ops.CSRMul)
|
||||||
|
def get_bprop_csr_mul(self):
|
||||||
|
"Back-propagation for CSRMul."
|
||||||
|
def bprop(csr_tensor, dense, out, dout):
|
||||||
|
indptr = csr_tensor.indptr
|
||||||
|
indices = csr_tensor.indices
|
||||||
|
values = csr_tensor.values
|
||||||
|
shape = csr_tensor.shape
|
||||||
|
|
||||||
|
csr_tensor_grad_value = F.csr_mul(F.make_csr_tensor(indptr, indices, dout, shape), dense)
|
||||||
|
csr_tensor_grad = F.make_csr_tensor(indptr, indices, csr_tensor_grad_value, shape)
|
||||||
|
dense_grad_value = F.mul(dout, values)
|
||||||
|
dense_grad = F.make_csr_tensor(indptr, indices, dense_grad_value, shape)
|
||||||
|
if len(dense.shape) == 1 or dense.shape[0] == 1:
|
||||||
|
dense_grad = F.csr_reduce_sum(dense_grad, 0)
|
||||||
|
elif dense.shape[1] == 1:
|
||||||
|
dense_grad = F.csr_reduce_sum(dense_grad, 1)
|
||||||
|
else:
|
||||||
|
row = F.csr2coo(indptr, indices.shape[0])
|
||||||
|
coo_idx = P.Stack(-1)((row, indices))
|
||||||
|
dense_grad = F.tensor_scatter_update(zeros_like(dense), coo_idx, dense_grad_value)
|
||||||
|
return csr_tensor_grad, dense_grad
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
@bprop_getters.register(_csr_ops.CSR2COO)
|
||||||
|
def get_bprop_csr2coo(self):
|
||||||
|
def bprop(indptr, nnz, out, dout):
|
||||||
|
return zeros_like(dout)
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
@bprop_getters.register(_csr_ops.COO2CSR)
|
||||||
|
def get_bprop_coo2csr(self):
|
||||||
|
def bprop(row_indices, height, out, dout):
|
||||||
|
return zeros_like(dout)
|
||||||
|
return bprop
|
||||||
|
|
|
@ -25,4 +25,7 @@ from .notequal import _notequal_akg
|
||||||
from .csr_reduce_sum import _csr_reduce_sum_akg
|
from .csr_reduce_sum import _csr_reduce_sum_akg
|
||||||
from .csr_mv import _csr_mv_akg
|
from .csr_mv import _csr_mv_akg
|
||||||
from .csr_mul import _csr_mul_akg
|
from .csr_mul import _csr_mul_akg
|
||||||
|
from .csr_gather import _csr_gather_akg
|
||||||
|
from .csr2coo import _csr2coo_akg
|
||||||
|
from .coo2csr import _coo2csr_akg
|
||||||
# Please insert op register in lexicographical order of the filename.
|
# Please insert op register in lexicographical order of the filename.
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""COO2CSR op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
|
||||||
|
|
||||||
|
coo2csr_op_info = AkgGpuRegOp("COO2CSR") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "row_indices") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(coo2csr_op_info)
|
||||||
|
def _coo2csr_akg():
|
||||||
|
"""COO2CSR AutoDiff register"""
|
||||||
|
return
|
|
@ -0,0 +1,29 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""CSR2COO op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
|
||||||
|
|
||||||
|
csr2coo_op_info = AkgGpuRegOp("CSR2COO") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "indptr") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(csr2coo_op_info)
|
||||||
|
def _csr2coo_akg():
|
||||||
|
"""CSR2COO AutoDiff register"""
|
||||||
|
return
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""CSRGatherop"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
|
||||||
|
|
||||||
|
csr_gather_op_info = AkgGpuRegOp("CSRGather") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "indptr") \
|
||||||
|
.input(1, "indices") \
|
||||||
|
.input(2, "dense") \
|
||||||
|
.output(0, "output") \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \
|
||||||
|
DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
|
||||||
|
DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(csr_gather_op_info)
|
||||||
|
def _csr_gather_akg():
|
||||||
|
"""CSRGather AutoDiff register"""
|
||||||
|
return
|
|
@ -22,9 +22,6 @@ csr_mv_op_info = AkgGpuRegOp("CSRMV") \
|
||||||
.input(2, "values") \
|
.input(2, "values") \
|
||||||
.input(4, "dense_tensor") \
|
.input(4, "dense_tensor") \
|
||||||
.output(0, "output") \
|
.output(0, "output") \
|
||||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.F32_Default, \
|
|
||||||
DataType.F32_Default, \
|
|
||||||
DataType.F32_Default) \
|
|
||||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, \
|
||||||
DataType.F32_Default, \
|
DataType.F32_Default, \
|
||||||
DataType.F32_Default) \
|
DataType.F32_Default) \
|
||||||
|
|
|
@ -346,4 +346,19 @@ def _add_nonetensor_tensor(x, y):
|
||||||
"""
|
"""
|
||||||
return x + y
|
return x + y
|
||||||
|
|
||||||
|
|
||||||
|
@_add_backward.register("CSRTensor", "CSRTensor")
|
||||||
|
def _add_csrtensor_csrtensor(x, y):
|
||||||
|
"""
|
||||||
|
Adds CSRTensor and CSRTensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (CSRTensor): x
|
||||||
|
y (CSRTensor): y
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CSRTensor.
|
||||||
|
"""
|
||||||
|
return F.make_csr_tensor(x.indptr, x.indices, x.values + y.values, x.shape)
|
||||||
|
|
||||||
hyper_add = base.HyperMap(_add_backward)
|
hyper_add = base.HyperMap(_add_backward)
|
||||||
|
|
|
@ -58,6 +58,12 @@ def _ones_like_coo_tensor(x):
|
||||||
return F.make_coo_tensor(F.coo_tensor_get_indices(x), values, F.coo_tensor_get_dense_shape(x))
|
return F.make_coo_tensor(F.coo_tensor_get_indices(x), values, F.coo_tensor_get_dense_shape(x))
|
||||||
|
|
||||||
|
|
||||||
|
@ones_like_leaf.register("CSRTensor")
|
||||||
|
def _ones_like_csr_tensor(x):
|
||||||
|
"""Returns a tensor with the same shape and dtype as x and all elements are 1."""
|
||||||
|
return F.make_csr_tensor(x.indptr, x.indices, ones_like(x.values), x.shape)
|
||||||
|
|
||||||
|
|
||||||
@ones_like_leaf.register("Function")
|
@ones_like_leaf.register("Function")
|
||||||
def _ones_like_func(x):
|
def _ones_like_func(x):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -58,6 +58,20 @@ def _zeros_like_tensor(x):
|
||||||
return F.zeros_like(x)
|
return F.zeros_like(x)
|
||||||
|
|
||||||
|
|
||||||
|
@zeros_like_leaf.register("COOTensor")
|
||||||
|
def _zeros_like_coo_tensor(x):
|
||||||
|
"""Returns a tensor with the same shape and dtype as x and all elements are 1."""
|
||||||
|
values = F.zeros_like(x.values)
|
||||||
|
return F.make_coo_tensor(x.indices, values, x.shape)
|
||||||
|
|
||||||
|
|
||||||
|
@zeros_like_leaf.register("CSRTensor")
|
||||||
|
def _zeros_like_csr_tensor(x):
|
||||||
|
"""Returns a tensor with the same shape and dtype as x and all elements are 1."""
|
||||||
|
values = F.zeros_like(x.values)
|
||||||
|
return F.make_csr_tensor(x.indptr, x.indices, values, x.shape)
|
||||||
|
|
||||||
|
|
||||||
@zeros_like_leaf.register("TypeType")
|
@zeros_like_leaf.register("TypeType")
|
||||||
def _zeros_like_type_type(x):
|
def _zeros_like_type_type(x):
|
||||||
"""Returns x because x is a type. This is usually used in backprop progress."""
|
"""Returns x because x is a type. This is usually used in backprop progress."""
|
||||||
|
|
|
@ -152,6 +152,9 @@ stack = P.Stack()
|
||||||
csr_mul = _csr_ops.CSRMul()
|
csr_mul = _csr_ops.CSRMul()
|
||||||
csr_mv = _csr_ops.CSRMV()
|
csr_mv = _csr_ops.CSRMV()
|
||||||
csr_reduce_sum = _csr_ops.CSRReduceSum()
|
csr_reduce_sum = _csr_ops.CSRReduceSum()
|
||||||
|
csr_gather = _csr_ops.CSRGather()
|
||||||
|
csr2coo = _csr_ops.CSR2COO()
|
||||||
|
coo2csr = _csr_ops.COO2CSR()
|
||||||
|
|
||||||
_select = P.Select()
|
_select = P.Select()
|
||||||
|
|
||||||
|
@ -576,6 +579,7 @@ not_in_dict = Primitive("not_in_dict")
|
||||||
mixed_precision_cast = Primitive("mixed_precision_cast")
|
mixed_precision_cast = Primitive("mixed_precision_cast")
|
||||||
broadcast_gradient_args = Primitive('BroadcastGradientArgs')
|
broadcast_gradient_args = Primitive('BroadcastGradientArgs')
|
||||||
array_reduce = Primitive('array_reduce')
|
array_reduce = Primitive('array_reduce')
|
||||||
|
zeros = P.Zeros()
|
||||||
zeros_like = P.ZerosLike()
|
zeros_like = P.ZerosLike()
|
||||||
distribute = Primitive('distribute')
|
distribute = Primitive('distribute')
|
||||||
embed = Primitive('embed')
|
embed = Primitive('embed')
|
||||||
|
@ -670,6 +674,11 @@ tensor_operator_registry.register('log', log)
|
||||||
tensor_operator_registry.register('floor', floor)
|
tensor_operator_registry.register('floor', floor)
|
||||||
# support sparse tensor operators
|
# support sparse tensor operators
|
||||||
tensor_operator_registry.register('csr_mul', csr_mul)
|
tensor_operator_registry.register('csr_mul', csr_mul)
|
||||||
|
tensor_operator_registry.register('csr2coo', csr2coo)
|
||||||
|
tensor_operator_registry.register('coo2csr', coo2csr)
|
||||||
tensor_operator_registry.register('narrow', narrow)
|
tensor_operator_registry.register('narrow', narrow)
|
||||||
|
tensor_operator_registry.register('sort', sort)
|
||||||
|
tensor_operator_registry.register('zeros', zeros)
|
||||||
|
tensor_operator_registry.register('tensor_scatter_update', tensor_scatter_update)
|
||||||
__all__ = [name for name in dir() if name[0] != "_"]
|
__all__ = [name for name in dir() if name[0] != "_"]
|
||||||
__all__.remove('Primitive')
|
__all__.remove('Primitive')
|
||||||
|
|
|
@ -20,6 +20,9 @@ class CSRReduceSum(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Reduces a dimension of a CSRTensor by summing all elements in the dimension.
|
Reduces a dimension of a CSRTensor by summing all elements in the dimension.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
This is an experimental prototype that is subject to change and/or deletion.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **sparse_tensor** (CSRTensor) - A CSRTensor.
|
- **sparse_tensor** (CSRTensor) - A CSRTensor.
|
||||||
- **axis** (int) - The dimensions to reduce.
|
- **axis** (int) - The dimensions to reduce.
|
||||||
|
@ -64,6 +67,9 @@ class CSRMV(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Sparse matrix-vector multiplication.
|
Sparse matrix-vector multiplication.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
This is an experimental prototype that is subject to change and/or deletion.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **sparse_tensor** (CSRTensor) - A CSRTensor.
|
- **sparse_tensor** (CSRTensor) - A CSRTensor.
|
||||||
- **dense_tensor** (Tensor) - A dense Tensor.
|
- **dense_tensor** (Tensor) - A dense Tensor.
|
||||||
|
@ -109,6 +115,9 @@ class CSRMul(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Elemwise multiplication on a CSRTensor and a dense tensor.
|
Elemwise multiplication on a CSRTensor and a dense tensor.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
This is an experimental prototype that is subject to change and/or deletion.
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
The op outputs a 1-D dense tensor whose shape and values are the same as input `CSRTensor.values`.
|
The op outputs a 1-D dense tensor whose shape and values are the same as input `CSRTensor.values`.
|
||||||
If expect a CSRTensor output, please use `*` directly, e.g. `x * y`, `x` or `y` can be CSRTensor.
|
If expect a CSRTensor output, please use `*` directly, e.g. `x * y`, `x` or `y` can be CSRTensor.
|
||||||
|
@ -151,3 +160,129 @@ class CSRMul(PrimitiveWithInfer):
|
||||||
"""Initialize CSRMul"""
|
"""Initialize CSRMul"""
|
||||||
self.init_prim_io_names(inputs=['indptr', 'indices', 'values', 'dense_shape', 'dense_tensor'],
|
self.init_prim_io_names(inputs=['indptr', 'indices', 'values', 'dense_shape', 'dense_tensor'],
|
||||||
outputs=['output'])
|
outputs=['output'])
|
||||||
|
|
||||||
|
|
||||||
|
class CSRGather(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Returns the values of a CSRTensor indexed from a dense tensor using indptr and indices.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
This is an experimental prototype that is subject to change and/or deletion.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **indptr** (Tensor) - A Tensor.
|
||||||
|
- **indices** (Tensor) - A Tensor.
|
||||||
|
- **dense** (Tensor) - A Tensor.
|
||||||
|
- **sparse_shape** (tuple) - A tuple of integers.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, the dtype is the same as `dense`, the first dimension is the same shape as `indices` and the remaining
|
||||||
|
dimensions are the same as ``dense[2:]``.
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``GPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import mindspore.nn as nn
|
||||||
|
>>> from mindspore import Tensor, ops
|
||||||
|
>>> from mindspore import dtype as mstype
|
||||||
|
>>> class Net(nn.Cell):
|
||||||
|
... def __init__(self):
|
||||||
|
... super(Net, self).__init__()
|
||||||
|
... self.op = ops.CSRGather()
|
||||||
|
...
|
||||||
|
... def construct(self, indptr, indices, dense, sparse_shape):
|
||||||
|
... return self.op(indptr, indices, dense, sparse_shape)
|
||||||
|
>>> indptr = Tensor([0, 1, 2])
|
||||||
|
>>> indices = Tensor([0, 1])
|
||||||
|
>>> sparse_shape = (2, 4)
|
||||||
|
>>> dense = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
|
||||||
|
>>> out = Net()(indptr, indices, dense, sparse_shape)
|
||||||
|
>>> print(out)
|
||||||
|
[1. 1.]
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize CSRGather"""
|
||||||
|
self.init_prim_io_names(inputs=['indptr', 'indices', 'dense', 'dense_shape'],
|
||||||
|
outputs=['output'])
|
||||||
|
|
||||||
|
|
||||||
|
class CSR2COO(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Converts the indptr of a CSRTensor to the row indices of a COOTensor.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
This is an experimental prototype that is subject to change and/or deletion.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **indptr** (Tensor) - A Tensor.
|
||||||
|
- **nnz** (int) - Denotes the number of non-zero elements in the sparse tensor.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, the dtype is the same as `indptr` and has shape (`nnz`,).
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``GPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import mindspore.nn as nn
|
||||||
|
>>> from mindspore import Tensor, ops
|
||||||
|
>>> class Net(nn.Cell):
|
||||||
|
... def __init__(self):
|
||||||
|
... super(Net, self).__init__()
|
||||||
|
... self.op = ops.CSR2COO()
|
||||||
|
...
|
||||||
|
... def construct(self, indptr, nnz):
|
||||||
|
... return self.op(indptr, nnz)
|
||||||
|
>>> indptr = Tensor([0, 1, 2])
|
||||||
|
>>> out = Net()(indptr, 2)
|
||||||
|
>>> print(out)
|
||||||
|
[1 1]
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize CSR2COO"""
|
||||||
|
self.init_prim_io_names(inputs=['indptr', 'nnz'], outputs=['output'])
|
||||||
|
|
||||||
|
|
||||||
|
class COO2CSR(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Converts the row indices of a COOTensor to the indptr of a CSRTensor.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
This is an experimental prototype that is subject to change and/or deletion.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **row_indices** (Tensor) - A Tensor.
|
||||||
|
- **height** (int) - the height of the first dimension of the sparse tensor.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, the dtype is the same as `row_indices` and has shape ('height' + 1,).
|
||||||
|
|
||||||
|
Supported Platforms:
|
||||||
|
``GPU``
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import mindspore.nn as nn
|
||||||
|
>>> from mindspore import Tensor, ops
|
||||||
|
>>> from mindspore import dtype as mstype
|
||||||
|
>>> class Net(nn.Cell):
|
||||||
|
... def __init__(self):
|
||||||
|
... super(Net, self).__init__()
|
||||||
|
... self.op = ops.COO2CSR()
|
||||||
|
...
|
||||||
|
... def construct(self, row_indices, height):
|
||||||
|
... return self.op(row_indices, height)
|
||||||
|
>>> row_indices = Tensor([0, 1])
|
||||||
|
>>> out = Net()(row_indices, 2)
|
||||||
|
>>> print(out)
|
||||||
|
[0 1 2]
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize COO2CSR"""
|
||||||
|
self.init_prim_io_names(inputs=['row_indices', 'height'], outputs=['output'])
|
||||||
|
|
|
@ -91,3 +91,39 @@ def test_coo_tensor_in_while():
|
||||||
assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0)
|
assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0)
|
||||||
assert np.allclose(out.values.asnumpy(), values.asnumpy(), .0, .0)
|
assert np.allclose(out.values.asnumpy(), values.asnumpy(), .0, .0)
|
||||||
assert out.shape == shape
|
assert out.shape == shape
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_coo_method():
|
||||||
|
"""
|
||||||
|
Feature: Test coo tensor methods.
|
||||||
|
Description: Test coo_tensor.to_csr(), coo_tensor.to_dense().
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
class COOToCSRNet(nn.Cell):
|
||||||
|
def construct(self, coo_tensor):
|
||||||
|
return coo_tensor.to_csr()
|
||||||
|
|
||||||
|
class COOToDenseNet(nn.Cell):
|
||||||
|
def construct(self, coo_tensor):
|
||||||
|
return coo_tensor.to_dense()
|
||||||
|
|
||||||
|
indices = Tensor([[1, 2], [0, 1]], dtype=mstype.int32)
|
||||||
|
values = Tensor([2, 1], dtype=mstype.float32)
|
||||||
|
shape = (3, 4)
|
||||||
|
coo_tensor = COOTensor(indices, values, shape)
|
||||||
|
|
||||||
|
to_csr_output = COOToCSRNet()(coo_tensor)
|
||||||
|
to_csr_expect_1 = np.array([0, 1, 2, 2], dtype=np.int32)
|
||||||
|
to_csr_expect_2 = np.array([1, 2], dtype=np.int32)
|
||||||
|
to_csr_expect_3 = np.array([1, 2], dtype=np.float32)
|
||||||
|
assert np.allclose(to_csr_output.indptr.asnumpy(), to_csr_expect_1)
|
||||||
|
assert np.allclose(to_csr_output.indices.asnumpy(), to_csr_expect_2)
|
||||||
|
assert np.allclose(to_csr_output.values.asnumpy(), to_csr_expect_3)
|
||||||
|
|
||||||
|
to_dense_output = COOToDenseNet()(coo_tensor)
|
||||||
|
to_dense_expect = np.array(
|
||||||
|
[[0., 1., 0., 0.], [0., 0., 2., 0.], [0., 0., 0., 0.]], dtype=np.float32)
|
||||||
|
assert np.allclose(to_dense_output.asnumpy(), to_dense_expect)
|
||||||
|
|
|
@ -18,7 +18,7 @@ import os
|
||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mindspore import Tensor, CSRTensor, ms_function, nn, context
|
from mindspore import Tensor, CSRTensor, ms_function, nn, context, ops
|
||||||
from mindspore.ops.operations import _csr_ops
|
from mindspore.ops.operations import _csr_ops
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore.train.serialization import export, load
|
from mindspore.train.serialization import export, load
|
||||||
|
@ -232,8 +232,8 @@ def test_csr_ops():
|
||||||
csr_reducesum = _csr_ops.CSRReduceSum()
|
csr_reducesum = _csr_ops.CSRReduceSum()
|
||||||
csrmv = _csr_ops.CSRMV()
|
csrmv = _csr_ops.CSRMV()
|
||||||
|
|
||||||
indptr = Tensor([0, 1, 2])
|
indptr = Tensor([0, 1, 2], dtype=mstype.int32)
|
||||||
indices = Tensor([0, 1])
|
indices = Tensor([0, 1], dtype=mstype.int32)
|
||||||
values = Tensor([2, 1], dtype=mstype.float32)
|
values = Tensor([2, 1], dtype=mstype.float32)
|
||||||
dense_shape = (2, 4)
|
dense_shape = (2, 4)
|
||||||
|
|
||||||
|
@ -331,8 +331,8 @@ def test_csrops_export_and_import_mindir():
|
||||||
sparse2 = dence_tensor * csr_tensor
|
sparse2 = dence_tensor * csr_tensor
|
||||||
return dense1, dense2, dense3, sparse1, sparse2
|
return dense1, dense2, dense3, sparse1, sparse2
|
||||||
|
|
||||||
indptr = Tensor([0, 1, 2])
|
indptr = Tensor([0, 1, 2], dtype=mstype.int32)
|
||||||
indices = Tensor([0, 1])
|
indices = Tensor([0, 1], dtype=mstype.int32)
|
||||||
values = Tensor([2, 1], dtype=mstype.float32)
|
values = Tensor([2, 1], dtype=mstype.float32)
|
||||||
shape = (2, 4)
|
shape = (2, 4)
|
||||||
dense_tensor = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
|
dense_tensor = Tensor([[1., 1, 1, 1], [1, 1, 1, 1]], dtype=mstype.float32)
|
||||||
|
@ -428,3 +428,111 @@ def test_dtype_csr_tensor():
|
||||||
out2 = graph_test()
|
out2 = graph_test()
|
||||||
assert out1 in [mstype.float32]
|
assert out1 in [mstype.float32]
|
||||||
assert out2 in [mstype.float32]
|
assert out2 in [mstype.float32]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_csr_bprop():
|
||||||
|
"""
|
||||||
|
Feature: Test back-propagation with CSR-related Ops.
|
||||||
|
Description: Test CSRReduceSum, CSRMul, CSRMV, CSRTensor.to_coo(), CSRTensor.to_dense().
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
class CSRMulNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CSRMulNet, self).__init__()
|
||||||
|
self.op = _csr_ops.CSRMul()
|
||||||
|
|
||||||
|
def construct(self, csr_tensor, dense):
|
||||||
|
return self.op(csr_tensor, dense)
|
||||||
|
|
||||||
|
class CSRReduceSumNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CSRReduceSumNet, self).__init__()
|
||||||
|
self.op = _csr_ops.CSRReduceSum()
|
||||||
|
|
||||||
|
def construct(self, csr_tensor, axis):
|
||||||
|
return self.op(csr_tensor, axis)
|
||||||
|
|
||||||
|
class CSRMVNet(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(CSRMVNet, self).__init__()
|
||||||
|
self.op = _csr_ops.CSRMV()
|
||||||
|
|
||||||
|
def construct(self, csr_tensor, dense):
|
||||||
|
return self.op(csr_tensor, dense)
|
||||||
|
|
||||||
|
class BpropNet(nn.Cell):
|
||||||
|
def __init__(self, net):
|
||||||
|
super(BpropNet, self).__init__()
|
||||||
|
self.net = net
|
||||||
|
self.grad_op = ops.GradOperation(get_all=True)
|
||||||
|
|
||||||
|
def construct(self, *inputs):
|
||||||
|
return self.grad_op(self.net)(*inputs)
|
||||||
|
|
||||||
|
indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
|
||||||
|
indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
|
||||||
|
values = Tensor(np.arange(6), dtype=mstype.float32)
|
||||||
|
dense_shape = (3, 4)
|
||||||
|
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
|
||||||
|
|
||||||
|
csr_mv_arg = Tensor([[1], [2], [3], [4]], dtype=mstype.float32)
|
||||||
|
csr_mv_output_1, csr_mv_output_2 = BpropNet(CSRMVNet())(csr_tensor, csr_mv_arg)
|
||||||
|
csr_mv_expect_1 = np.array([4, 1, 2, 3, 2, 4], dtype=np.float32)
|
||||||
|
csr_mv_expect_2 = np.array([[1], [6], [3], [5]], dtype=np.float32)
|
||||||
|
assert np.allclose(csr_mv_output_1.values.asnumpy(), csr_mv_expect_1)
|
||||||
|
assert np.allclose(csr_mv_output_2.asnumpy(), csr_mv_expect_2)
|
||||||
|
|
||||||
|
csr_reduce_sum_output = BpropNet(CSRReduceSumNet())(csr_tensor, 1)
|
||||||
|
csr_reduce_sum_expect = np.ones(6, dtype=np.float32)
|
||||||
|
assert np.allclose(csr_reduce_sum_output[0].values.asnumpy(), csr_reduce_sum_expect)
|
||||||
|
|
||||||
|
csr_mul_arg_1 = Tensor([[1], [2], [3]], dtype=mstype.float32)
|
||||||
|
csr_mul_output_1_1, csr_mul_output_1_2 = BpropNet(CSRMulNet())(csr_tensor, csr_mul_arg_1)
|
||||||
|
csr_mul_expect_1_1 = np.array([1, 2, 2, 2, 3, 3], dtype=np.float32)
|
||||||
|
csr_mul_expect_1_2 = np.array([[0], [6], [9]], dtype=np.float32)
|
||||||
|
assert np.allclose(csr_mul_output_1_1.values.asnumpy(), csr_mul_expect_1_1)
|
||||||
|
assert np.allclose(csr_mul_output_1_2.asnumpy(), csr_mul_expect_1_2)
|
||||||
|
|
||||||
|
csr_mul_arg_2 = Tensor(np.arange(12).reshape(3, 4), dtype=mstype.float32)
|
||||||
|
csr_mul_output_2_1, csr_mul_output_2_2 = BpropNet(CSRMulNet())(csr_tensor, csr_mul_arg_2)
|
||||||
|
csr_mul_expect_2_1 = np.array([3, 4, 5, 6, 9, 11], dtype=np.float32)
|
||||||
|
csr_mul_expect_2_2 = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32)
|
||||||
|
assert np.allclose(csr_mul_output_2_1.values.asnumpy(), csr_mul_expect_2_1)
|
||||||
|
assert np.allclose(csr_mul_output_2_2.asnumpy(), csr_mul_expect_2_2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_csr_method():
|
||||||
|
"""
|
||||||
|
Feature: Test csr tensor methods.
|
||||||
|
Description: Test csr_tensor.to_coo(), csr_tensor.to_dense().
|
||||||
|
Expectation: Success.
|
||||||
|
"""
|
||||||
|
class CSRToCOONet(nn.Cell):
|
||||||
|
def construct(self, csr_tensor):
|
||||||
|
return csr_tensor.to_coo()
|
||||||
|
|
||||||
|
class CSRToDenseNet(nn.Cell):
|
||||||
|
def construct(self, csr_tensor):
|
||||||
|
return csr_tensor.to_dense()
|
||||||
|
|
||||||
|
indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
|
||||||
|
indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
|
||||||
|
values = Tensor(np.arange(6), dtype=mstype.float32)
|
||||||
|
dense_shape = (3, 4)
|
||||||
|
csr_tensor = CSRTensor(indptr, indices, values, dense_shape)
|
||||||
|
|
||||||
|
to_coo_output = CSRToCOONet()(csr_tensor)
|
||||||
|
to_coo_expect_1 = np.array([[0, 3], [1, 0], [1, 1], [1, 2], [2, 1], [2, 3]], dtype=np.int32)
|
||||||
|
to_coo_expect_2 = np.arange(6).astype(np.float32)
|
||||||
|
assert np.allclose(to_coo_output.indices.asnumpy(), to_coo_expect_1)
|
||||||
|
assert np.allclose(to_coo_output.values.asnumpy(), to_coo_expect_2)
|
||||||
|
|
||||||
|
to_dense_output = CSRToDenseNet()(csr_tensor)
|
||||||
|
to_dense_expect = np.array([[0, 0, 0, 0], [1, 2, 3, 0], [0, 4, 0, 5]], np.float32)
|
||||||
|
assert np.allclose(to_dense_output.asnumpy(), to_dense_expect)
|
||||||
|
|
Loading…
Reference in New Issue