!39833 move all sparse infer functions to core/ops/

Merge pull request !39833 from 杨林枫/prim_infer_refactor
i-robot 2022-08-09 04:14:04 +00:00 committed by Gitee
commit 38065a1a9c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
38 changed files with 1783 additions and 596 deletions
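
The refactor applies one pattern throughout: each sparse infer function leaves the shared abstract/ sources and the hand-maintained GetPrimitiveToEvalImplMap(), and becomes a self-registering operator under core/ops/. A condensed, illustrative comparison assembled from the diffs below (not a complete compilable file; COOTensorGetValues is just one of the moved ops):

// Before: infer function declared in the shared abstract headers and wired up manually
// in GetPrimitiveToEvalImplMap() (this entry is removed by the commit):
//   {prim::kPrimCOOTensorGetValues, R{InferImplCOOTensorGetValues, nullptr, true}},
//
// After: a dedicated core/ops/coo_tensor_get_values.{h,cc} defines the operator class
// and its infer function, and registers itself:
MIND_API_OPERATOR_IMPL(COOTensorGetValues, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(COOTensorGetValues, prim::kPrimCOOTensorGetValues, COOTensorGetValuesInfer, nullptr, true);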


@@ -129,30 +129,6 @@ AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitiveP
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T>
std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeCOOTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCOOTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCOOTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCOOTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRElementWise(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRMM(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -163,18 +139,6 @@ AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRTensorGetIndptr(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplCSRTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,


@@ -34,66 +34,10 @@
namespace {
constexpr auto kRankSize = "rank_size";
inline void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp) {
constexpr auto kCSRMulBatchPos = 2;
int dlen = mindspore::SizeToInt(sparse_shp.size()) - mindspore::SizeToInt(dense_shp.size());
if (dlen < 0) {
MS_EXCEPTION(mindspore::ValueError) << "Currently, only support dense tensor broadcast to sparse tensor, "
<< "but sparse tensor has " << sparse_shp.size() << " dimensions, "
<< "and dense tensor has " << dense_shp.size() << " dimensions, ";
}
for (int i = 0; i < dlen; i++) {
(void)dense_shp.insert(dense_shp.begin(), 1);
}
if (sparse_shp.size() != dense_shp.size()) {
MS_LOG(EXCEPTION) << "Failure: sparse_shp.size() != dense_shp.size().";
}
if (sparse_shp.size() < 1) {
MS_LOG(EXCEPTION) << "Failure: dense tensor and sparse tensor shapes cannot be zero.";
}
for (size_t i = 0; i < sparse_shp.size(); i++) {
auto s = sparse_shp[i];
auto d = dense_shp[i];
if (i < kCSRMulBatchPos) {
if (d != s && d != 1) {
MS_EXCEPTION(mindspore::ValueError) << "Dense shape cannot broadcast to sparse shape.";
}
} else {
if (d != s) {
MS_EXCEPTION(mindspore::ValueError)
<< "Currently, sparse shape and dense shape must equal in feature dimensions.";
}
}
}
}
inline void CheckSparseShape(const size_t shape_size, const size_t expected_dim, const std::string &arg_name) {
if (shape_size != expected_dim) {
MS_EXCEPTION(mindspore::ValueError) << arg_name << " must be a " << expected_dim
<< "-dimensional tensor, but got a " << shape_size << "-dimensional tensor.";
}
}
inline void CheckSparseIndicesDtype(const mindspore::TypePtr data_type, const std::string &arg_name) {
if (!(data_type->equal(mindspore::kInt16) || data_type->equal(mindspore::kInt32) ||
data_type->equal(mindspore::kInt64))) {
MS_EXCEPTION(mindspore::TypeError) << "The dtype of " << arg_name << " must be Int16 or Int32 or Int64, but got "
<< data_type->ToString() << ".";
}
}
inline void CheckSparseIndicesDtypeInt32(const mindspore::TypePtr data_type, const std::string &arg_name) {
if (!data_type->equal(mindspore::kInt32)) {
MS_EXCEPTION(mindspore::TypeError) << "The dtype of " << arg_name << " only support Int32 for now, but got "
<< data_type->ToString() << ".";
}
}
} // namespace
namespace mindspore {
namespace abstract {
constexpr auto kCSRDenseShape = "dense_shape";
constexpr auto kCSRAxis = "axis";
constexpr auto kCSRAvgRows = "csr_avg_rows";
constexpr auto kIsCSR = "is_csr";
AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// An object of a subclass of AbstractBase
@@ -343,433 +287,6 @@ AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const Primitive
return args_spec_list[0];
}
AbstractBasePtr InferImplMakeCOOTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors and a tuple.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kSizeThree);
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexZero);
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexOne);
auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, kIndexTwo);
auto indices_dtype = indices->element()->BuildType();
CheckSparseIndicesDtype(indices_dtype, "Indices");
auto indices_shp = indices->shape()->shape();
CheckSparseShape(indices_shp.size(), kSizeTwo, "Indices");
auto values_shp = values->shape()->shape();
CheckSparseShape(values_shp.size(), kSizeOne, "Values");
if (indices_shp[kIndexZero] != values_shp[kIndexZero]) {
MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexZero << "]` must be equal to `values.shape["
<< kIndexZero << "]`, but got `indices.shape[" << kIndexZero
<< "]`: " << indices_shp[kIndexZero] << " and `values.shape[" << kIndexZero
<< "]`: " << values_shp[kIndexZero];
}
constexpr int64_t kDimTwo = 2;
if (indices_shp[kIndexOne] != kDimTwo) {
MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexOne << "]` must be " << kDimTwo << ",but got "
<< indices_shp[kIndexOne];
}
for (const auto &elem_type : dense_shape->ElementsType()) {
if (!elem_type->isa<Int>()) {
MS_EXCEPTION(TypeError) << "For COOTensor, the element type of `shape` must be Int, but got "
<< elem_type->ToString();
}
}
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(dense_shape_value);
auto shp = dense_shape_value->value();
auto min_elem = *std::min_element(std::begin(shp), std::end(shp));
if (min_elem <= 0) {
MS_EXCEPTION(ValueError) << "For COOTensor, the element of `shape` must be positive integer. But got " << min_elem
<< "int it";
}
ShapeVector dense_shape_vec;
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
[](const ValuePtr &e) -> int64_t {
auto elem = GetValue<int64_t>(e);
return elem;
});
if (LongToSize(indices_shp[kIndexOne]) != dense_shape_vec.size()) {
MS_EXCEPTION(TypeError) << "For COOTensor, `indices.shape[" << indices_shp << "]` must be equal to the second "
<< "dimension of `indices`: " << dense_shape_vec.size() << " but got "
<< indices_shp[kIndexOne];
}
for (auto dense_shape_elem : dense_shape_vec) {
if (dense_shape_elem <= 0) {
MS_EXCEPTION(TypeError) << "For COOTensor, the element of `shape` must be positive, but got "
<< dense_shape_value->ToString();
}
}
AbstractBasePtrList element_list{indices, values, dense_shape};
return std::make_shared<abstract::AbstractCOOTensor>(element_list);
}
AbstractBasePtr InferImplCOOTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors and a tuple.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto sparse_tensor = CheckArg<AbstractCOOTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(sparse_tensor->values());
return sparse_tensor->values();
}
AbstractBasePtr InferImplCOOTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors and a tuple.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto sparse_tensor = CheckArg<AbstractCOOTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(sparse_tensor->indices());
return sparse_tensor->indices();
}
AbstractBasePtr InferImplCOOTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors and a tuple.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto sparse_tensor = CheckArg<AbstractCOOTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(sparse_tensor->shape());
return sparse_tensor->shape();
}
ShapeVector ConvertToShapeVector(const AbstractTuplePtr &shape) {
auto shape_value = shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(shape_value);
ShapeVector shape_vec;
(void)std::transform(std::begin(shape_value->value()), std::end(shape_value->value()), std::back_inserter(shape_vec),
[](const ValuePtr &e) -> int64_t {
auto elem = GetValue<int64_t>(e);
return elem;
});
return shape_vec;
}
AbstractBasePtr InferImplCSRElementWise(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a sparse tensor and a dense tensor.
constexpr auto kCSRElementwiseInputsNum = 5;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kCSRElementwiseInputsNum);
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
auto shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 3);
auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, 4);
MS_EXCEPTION_IF_NULL(indptr);
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(values);
MS_EXCEPTION_IF_NULL(shape);
MS_EXCEPTION_IF_NULL(dense);
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
ShapeVector sparse_shape = ConvertToShapeVector(shape);
auto dense_shape = dense->shape()->shape();
CheckSparseShape(sparse_shape, dense_shape);
auto ret = values->Broaden();
// SetAttr
auto nnz_vec = indices->shape()->shape();
auto csr_avg_rows = nnz_vec[0] / dense_shape[0];
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kIsCSR, MakeValue(true));
return ret;
}
AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
constexpr auto kCSRMVInputsNum = 5;
constexpr auto kCSRMVShapeSize = 2;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kCSRMVInputsNum);
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
auto shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 3);
auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, 4);
MS_EXCEPTION_IF_NULL(indptr);
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(values);
MS_EXCEPTION_IF_NULL(shape);
MS_EXCEPTION_IF_NULL(dense);
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
ShapeVector sparse_shape = ConvertToShapeVector(shape);
ShapeVector dense_shape = dense->shape()->shape();
if (sparse_shape.size() != kCSRMVShapeSize || dense_shape.size() != kCSRMVShapeSize) {
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMVShapeSize << "-D inputs! "
<< "But csr tensor has " << sparse_shape.size() << " dimensions, "
<< "and dense tensor has " << dense_shape.size() << " dimension(s). ";
}
if (dense_shape[kIndexZero] != sparse_shape[kIndexOne] || dense_shape[kIndexOne] != 1) {
MS_EXCEPTION(ValueError) << "The dense_vector's shape should be (" << sparse_shape[kIndexOne] << ", 1)"
<< ", but its current shape is: "
<< "(" << dense_shape[kIndexZero] << ", " << dense_shape[kIndexOne] << ").";
}
ShapeVector out_shape = {sparse_shape[kIndexZero], dense_shape[kIndexOne]};
auto ret = std::make_shared<AbstractTensor>(values->element()->BuildType(), out_shape);
// SetAttr
auto nnz_vec = indices->shape()->shape();
auto csr_avg_rows = nnz_vec[kIndexZero] / dense_shape[kIndexZero];
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kIsCSR, MakeValue(true));
return ret;
}
AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a sparse tensor and an axis.
constexpr auto kCSRReduceSumInputsNum = 5;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kCSRReduceSumInputsNum);
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
auto shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 3);
auto axis = CheckArg<AbstractScalar>(op_name, args_spec_list, 4);
MS_EXCEPTION_IF_NULL(indptr);
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(values);
MS_EXCEPTION_IF_NULL(shape);
MS_EXCEPTION_IF_NULL(axis);
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
ShapeVector sparse_shape = ConvertToShapeVector(shape);
ShapeVector out_shape = sparse_shape;
MS_EXCEPTION_IF_NULL(axis->BuildValue());
if (axis->BuildValue()->isa<Int32Imm>() || axis->BuildValue()->isa<Int64Imm>()) {
int64_t axis_value = GetValue<int64_t>(axis->BuildValue());
int64_t dim = static_cast<int64_t>(sparse_shape.size());
if (axis_value != 1 && axis_value != 1 - dim) {
MS_EXCEPTION(ValueError) << "For CSRReduceSum, `axis` should be 1 or 1-dim. But got `axis`: " << axis_value
<< "and `1- dim`: " << 1 - dim;
}
if (axis_value < 0) {
axis_value += dim;
}
out_shape[LongToSize(axis_value)] = 1;
primitive->set_attr(kCSRAxis, MakeValue(axis_value));
} else {
MS_EXCEPTION(TypeError) << "For CSRReduceSum, `axis` should be int32 or int64, but got "
<< axis->BuildType()->ToString();
}
MS_EXCEPTION_IF_NULL(values->element());
auto ret = std::make_shared<AbstractTensor>(values->element()->BuildType(), out_shape);
// SetAttr
auto nnz_vec = indices->shape()->shape();
auto csr_avg_rows = nnz_vec[0] / sparse_shape[0];
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kIsCSR, MakeValue(true));
return ret;
}
AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: the indptr and indices of a sparse csr tensor, a dense tensor, and the shape of the sparse tensor.
constexpr size_t csr_row_num = 2;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kSizeFour);
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexZero);
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexOne);
auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexTwo);
auto sparse_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, kIndexThree);
MS_EXCEPTION_IF_NULL(indptr);
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(dense);
MS_EXCEPTION_IF_NULL(sparse_shape);
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
auto shape_value = sparse_shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(shape_value);
auto nnz_vec = indices->shape()->shape();
int64_t csr_avg_rows = nnz_vec[0] / GetValue<int64_t>(shape_value->value()[0]);
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kIsCSR, MakeValue(true));
MS_EXCEPTION_IF_NULL(indices->shape());
ShapeVector out_shape = indices->shape()->shape();
MS_EXCEPTION_IF_NULL(dense->shape());
ShapeVector dense_shape = dense->shape()->shape();
for (size_t i = csr_row_num; i < dense_shape.size(); ++i) {
out_shape.push_back(dense_shape[i]);
}
MS_EXCEPTION_IF_NULL(dense->element());
auto ret = std::make_shared<AbstractTensor>(dense->element()->BuildType(), out_shape);
return ret;
}
AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: the indptr of a sparse csr tensor, and the number of non-zero elements.
constexpr auto kCSRArgsSize = 2;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
auto nnz = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(indptr);
MS_EXCEPTION_IF_NULL(nnz);
MS_EXCEPTION_IF_NULL(nnz->BuildValue());
ShapeVector out_shape;
if (nnz->BuildValue()->isa<Int32Imm>() || nnz->BuildValue()->isa<Int64Imm>()) {
int64_t nnz_value = GetValue<int64_t>(nnz->BuildValue());
out_shape.push_back(nnz_value);
} else {
MS_EXCEPTION(ValueError) << "Currently, only support Integer nnz.";
}
MS_EXCEPTION_IF_NULL(indptr->shape());
auto num_rows = indptr->shape()->shape()[0] - 1;
int csr_avg_rows = GetValue<int64_t>(nnz->BuildValue()) / num_rows;
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kIsCSR, MakeValue(true));
MS_EXCEPTION_IF_NULL(indptr->element());
auto ret = std::make_shared<AbstractTensor>(indptr->element()->BuildType(), out_shape);
return ret;
}
AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: the row indices of a sparse coo tensor, and the size of its first dimension.
constexpr auto kCSRArgsSize = 2;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);
auto row_indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto height = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(row_indices);
MS_EXCEPTION_IF_NULL(height);
CheckSparseIndicesDtypeInt32(row_indices->element()->BuildType(), "row_indices");
MS_EXCEPTION_IF_NULL(height->BuildValue());
ShapeVector out_shape;
if (height->BuildValue()->isa<Int32Imm>() || height->BuildValue()->isa<Int64Imm>()) {
int64_t height_value = GetValue<int64_t>(height->BuildValue());
out_shape.push_back(height_value + 1);
} else {
MS_EXCEPTION(ValueError) << "Currently, only support Integer height.";
}
MS_EXCEPTION_IF_NULL(row_indices->element());
auto ret = std::make_shared<AbstractTensor>(row_indices->element()->BuildType(), out_shape);
return ret;
}
AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: three tensors and a tuple.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kSizeFour);
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexZero);
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexOne);
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexTwo);
auto shape = CheckArg<AbstractTuple>(op_name, args_spec_list, kIndexThree);
auto indptr_dtype = indptr->element()->BuildType();
auto indices_dtype = indices->element()->BuildType();
CheckSparseIndicesDtype(indptr_dtype, "indptr");
CheckSparseIndicesDtype(indices_dtype, "indices");
auto indptr_shp = indptr->shape()->shape();
CheckSparseShape(indptr_shp.size(), kSizeOne, "Indptr");
auto indices_shp = indices->shape()->shape();
CheckSparseShape(indices_shp.size(), kSizeOne, "Indices");
auto values_shp = values->shape()->shape();
if (indices_shp[kIndexZero] != values_shp[kIndexZero]) {
MS_EXCEPTION(ValueError) << "Indices and values must have same size, but got: values length: "
<< values_shp[kIndexZero] << ", indices length " << indices_shp[kIndexZero];
}
auto shape_value = shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(shape_value);
auto shp = shape_value->value();
ShapeVector shape_vec;
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(shape_vec), [](const ValuePtr &e) -> int64_t {
auto elem = GetValue<int64_t>(e);
return elem;
});
if (values_shp.size() + 1 != shape_vec.size()) {
MS_EXCEPTION(ValueError) << "Values' dimension should equal to csr_tensor's dimension - 1, but got"
<< "Values' dimension: " << values_shp.size()
<< ", csr_tensor's dimension: " << shape_vec.size() << ".";
}
if (shape_vec[kIndexZero] + 1 != indptr_shp[kIndexZero]) {
MS_EXCEPTION(ValueError) << "Indptr must have length (1 + shape[0]), but got: " << indptr_shp[kIndexZero];
}
size_t shape_size = 1;
auto shape_types = shape->ElementsType();
for (size_t i = 0; i < shape_vec.size(); ++i) {
if (shape_vec[i] <= 0) {
MS_EXCEPTION(ValueError) << "The element of shape must be positive, but got " << shape_value->ToString();
}
if ((i > 1) && (shape_vec[i] != values_shp[i - 1])) {
MS_EXCEPTION(ValueError) << "csr_tensor's shape should match with values' shape.";
}
if (!shape_types[i]->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The element type of shape must be Int, but got " << shape_types[i]->ToString();
}
shape_size *= LongToSize(shape_vec[i]);
}
if (static_cast<int64_t>(shape_size) < values_shp[kIndexZero]) {
MS_EXCEPTION(ValueError) << "Shape total size: " << shape_size << " is too small to hold " << values_shp[kIndexZero]
<< " non-zero values.";
}
AbstractBasePtrList element_list{indptr, indices, values, shape};
return std::make_shared<abstract::AbstractCSRTensor>(element_list);
}
template <typename T>
std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
return CheckArg<T>(op_name, args_spec_list, 0);
}
AbstractBasePtr InferImplCSRTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
auto csr_tensor = InferSparseAttr<AbstractCSRTensor>(primitive, args_spec_list);
MS_EXCEPTION_IF_NULL(csr_tensor->values());
return csr_tensor->values();
}
AbstractBasePtr InferImplCSRTensorGetIndptr(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
auto csr_tensor = InferSparseAttr<AbstractCSRTensor>(primitive, args_spec_list);
MS_EXCEPTION_IF_NULL(csr_tensor->indptr());
return csr_tensor->indptr();
}
AbstractBasePtr InferImplCSRTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
auto csr_tensor = InferSparseAttr<AbstractCSRTensor>(primitive, args_spec_list);
MS_EXCEPTION_IF_NULL(csr_tensor->indices());
return csr_tensor->indices();
}
AbstractBasePtr InferImplCSRTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
auto csr_tensor = InferSparseAttr<AbstractCSRTensor>(primitive, args_spec_list);
MS_EXCEPTION_IF_NULL(csr_tensor->shape());
return csr_tensor->shape();
}
AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
@@ -1008,48 +525,5 @@ AbstractBasePtr InferImplAdamApplyOneWithDecay(const AnalysisEnginePtr &, const
AbstractBasePtrList rets = {add1, add0, sub0};
return std::make_shared<AbstractTuple>(rets);
}
AbstractBasePtr InferImplCSRMM(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a sparse tensor and a dense tensor.
constexpr auto kCSRMMInputsNum = 5;
constexpr auto kCSRMMShapeSize = 2;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kCSRMMInputsNum);
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
auto shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 3);
auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, 4);
MS_EXCEPTION_IF_NULL(indptr);
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(values);
MS_EXCEPTION_IF_NULL(shape);
MS_EXCEPTION_IF_NULL(dense);
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
ShapeVector sparse_shape = ConvertToShapeVector(shape);
auto dense_shape = dense->shape()->shape();
if (sparse_shape.size() != kCSRMMShapeSize || dense_shape.size() != kCSRMMShapeSize) {
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMMShapeSize << "-D inputs! "
<< "But csr tensor has " << sparse_shape.size() << " dimensions, "
<< "and dense tensor has " << dense_shape.size() << " dimensions, ";
}
if (dense_shape[kIndexZero] != sparse_shape[kIndexOne]) {
MS_EXCEPTION(ValueError) << "The dense's shape[0] should be equal to csr tensor's shape[1]"
<< ", but dense's shape[0] is: " << dense_shape[kIndexZero]
<< " and csr tensor's shape[1] is " << sparse_shape[kIndexOne];
}
ShapeVector out_shape = {sparse_shape[kIndexZero], dense_shape[kIndexOne]};
auto ret = std::make_shared<AbstractTensor>(values->element()->BuildType(), out_shape);
// SetAttr
auto nnz_vec = indices->shape()->shape();
auto csr_avg_rows = nnz_vec[kIndexZero] / dense_shape[kIndexZero];
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kIsCSR, MakeValue(true));
return ret;
}
} // namespace abstract
} // namespace mindspore
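
Note: the helpers and attribute keys removed above (CheckSparseShape, CheckSparseIndicesDtype, CheckSparseIndicesDtypeInt32, ConvertToShapeVector, kCSRAvgRows, kIsCSR) are still called unqualified by the new core/ops/ sources further down, which include "ops/op_utils.h". Their new location is not part of this excerpt; the sketch below only restates, as an assumption, the declarations those sources appear to rely on (signatures copied from the code deleted above):

// Assumed shared declarations (destination file not shown in this diff):
constexpr auto kCSRAvgRows = "csr_avg_rows";
constexpr auto kIsCSR = "is_csr";
void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp);
void CheckSparseShape(const size_t shape_size, const size_t expected_dim, const std::string &arg_name);
void CheckSparseIndicesDtype(const TypePtr data_type, const std::string &arg_name);
void CheckSparseIndicesDtypeInt32(const TypePtr data_type, const std::string &arg_name);
ShapeVector ConvertToShapeVector(const abstract::AbstractTuplePtr &shape);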


@@ -318,31 +318,12 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimDebug, R{InferImplDebug, nullptr, true}},
// Dynamic shape testing
{prim::kPrimGpuConvertToDynamicShape, R{InferImplGpuConvertToDynamicShape, nullptr, true}},
// COOTensor
{prim::kPrimMakeCOOTensor, R{InferImplMakeCOOTensor, nullptr, true}},
{prim::kPrimCOOTensorGetValues, R{InferImplCOOTensorGetValues, nullptr, true}},
{prim::kPrimCOOTensorGetIndices, R{InferImplCOOTensorGetIndices, nullptr, true}},
{prim::kPrimCOOTensorGetDenseShape, R{InferImplCOOTensorGetDenseShape, nullptr, true}},
// RowTensor
{prim::kPrimMakeRowTensor, R{InferImplMakeRowTensor, nullptr, true}},
{prim::kPrimRowTensorGetValues, R{InferImplRowTensorGetValues, nullptr, true}},
{prim::kPrimRowTensorGetIndices, R{InferImplRowTensorGetIndices, nullptr, true}},
{prim::kPrimRowTensorGetDenseShape, R{InferImplRowTensorGetDenseShape, nullptr, true}},
{prim::kPrimRowTensorAdd, R{InferImplRowTensorAdd, nullptr, false}},
// CSRTensor
{prim::kPrimMakeCSRTensor, R{InferImplMakeCSRTensor, nullptr, true}},
{prim::kPrimCSRTensorGetValues, R{InferImplCSRTensorGetValues, nullptr, true}},
{prim::kPrimCSRTensorGetIndptr, R{InferImplCSRTensorGetIndptr, nullptr, true}},
{prim::kPrimCSRTensorGetIndices, R{InferImplCSRTensorGetIndices, nullptr, true}},
{prim::kPrimCSRTensorGetDenseShape, R{InferImplCSRTensorGetDenseShape, nullptr, true}},
{prim::kPrimCSRMul, R{InferImplCSRElementWise, nullptr, true}},
{prim::kPrimCSRDiv, R{InferImplCSRElementWise, nullptr, true}},
{prim::kPrimCSRMV, R{InferImplCSRMV, nullptr, true}},
{prim::kPrimCSRMM, R{InferImplCSRMM, nullptr, true}},
{prim::kPrimCSRReduceSum, R{InferImplCSRReduceSum, nullptr, true}},
{prim::kPrimCSRGather, R{InferImplCSRGather, nullptr, true}},
{prim::kPrimCSR2COO, R{InferImplCSR2COO, nullptr, true}},
{prim::kPrimCOO2CSR, R{InferImplCOO2CSR, nullptr, true}},
// Comm Ops
{prim::kPrimAllSwap, R{InferImplAllSwap, nullptr, true}},
{prim::kPrimMemCpyAsync, R{InferImplMemCpyAsync, nullptr, true}},


@@ -0,0 +1,41 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/coo_tensor_get_indices.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
AbstractBasePtr COOTensorGetIndicesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
auto coo_tensor = InferSparseAttr<abstract::AbstractCOOTensor>(primitive, args_spec_list);
MS_EXCEPTION_IF_NULL(coo_tensor->indices());
return coo_tensor->indices();
}
MIND_API_OPERATOR_IMPL(COOTensorGetIndices, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(COOTensorGetIndices, prim::kPrimCOOTensorGetIndices, COOTensorGetIndicesInfer, nullptr,
true);
} // namespace ops
} // namespace mindspore


@@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_COOTENSOR_GET_INDICES_H_
#define MINDSPORE_CORE_OPS_COOTENSOR_GET_INDICES_H_
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCOOTensorGetIndices = "COOTensorGetIndices";
class MIND_API COOTensorGetIndices : public BaseOperator {
public:
MIND_API_BASE_MEMBER(COOTensorGetIndices);
/// \brief Constructor.
COOTensorGetIndices() : BaseOperator(kNameCOOTensorGetIndices) {}
};
abstract::AbstractBasePtr COOTensorGetIndicesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_COOTENSOR_GET_INDICES_H_


@@ -0,0 +1,41 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/coo_tensor_get_shape.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
abstract::AbstractBasePtr COOTensorGetShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
auto coo_tensor = InferSparseAttr<abstract::AbstractCOOTensor>(primitive, args_spec_list);
MS_EXCEPTION_IF_NULL(coo_tensor->shape());
return coo_tensor->shape();
}
MIND_API_OPERATOR_IMPL(COOTensorGetShape, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(COOTensorGetShape, prim::kPrimCOOTensorGetDenseShape, COOTensorGetShapeInfer, nullptr,
true);
} // namespace ops
} // namespace mindspore


@@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_COOTENSOR_GET_SHAPE_H_
#define MINDSPORE_CORE_OPS_COOTENSOR_GET_SHAPE_H_
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCOOTensorGetShape = "COOTensorGetShape";
class MIND_API COOTensorGetShape : public BaseOperator {
public:
MIND_API_BASE_MEMBER(COOTensorGetShape);
/// \brief Constructor.
COOTensorGetShape() : BaseOperator(kNameCOOTensorGetShape) {}
};
abstract::AbstractBasePtr COOTensorGetShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_COOTENSOR_GET_SHAPE_H_


@@ -0,0 +1,40 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/coo_tensor_get_values.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
abstract::AbstractBasePtr COOTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
auto coo_tensor = InferSparseAttr<abstract::AbstractCOOTensor>(primitive, args_spec_list);
MS_EXCEPTION_IF_NULL(coo_tensor->values());
return coo_tensor->values();
}
MIND_API_OPERATOR_IMPL(COOTensorGetValues, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(COOTensorGetValues, prim::kPrimCOOTensorGetValues, COOTensorGetValuesInfer, nullptr, true);
} // namespace ops
} // namespace mindspore


@@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_COOTENSOR_GET_VALUES_H_
#define MINDSPORE_CORE_OPS_COOTENSOR_GET_VALUES_H_
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCOOTensorGetValues = "COOTensorGetValues";
class MIND_API COOTensorGetValues : public BaseOperator {
public:
MIND_API_BASE_MEMBER(COOTensorGetValues);
/// \brief Constructor.
COOTensorGetValues() : BaseOperator(kNameCOOTensorGetValues) {}
};
abstract::AbstractBasePtr COOTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_COOTENSOR_GET_VALUES_H_


@@ -0,0 +1,59 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/coo_to_csr.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
using abstract::AbstractScalar;
using abstract::AbstractTensor;
using abstract::AbstractTuple;
AbstractBasePtr COO2CSRInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
// Inputs: the row indices of a sparse coo tensor, and the size of its first dimension.
constexpr auto kCSRArgsSize = 2;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, input_args, kCSRArgsSize);
auto row_indices = abstract::CheckArg<AbstractTensor>(op_name, input_args, 0);
auto height = abstract::CheckArg<AbstractScalar>(op_name, input_args, 1);
MS_EXCEPTION_IF_NULL(row_indices);
MS_EXCEPTION_IF_NULL(height);
CheckSparseIndicesDtypeInt32(row_indices->element()->BuildType(), "row_indices");
MS_EXCEPTION_IF_NULL(height->BuildValue());
ShapeVector out_shape;
if (height->BuildValue()->isa<Int32Imm>() || height->BuildValue()->isa<Int64Imm>()) {
int64_t height_value = GetValue<int64_t>(height->BuildValue());
out_shape.push_back(height_value + 1);
} else {
MS_EXCEPTION(ValueError) << "Currently, only support Integer height.";
}
MS_EXCEPTION_IF_NULL(row_indices->element());
auto ret = std::make_shared<AbstractTensor>(row_indices->element()->BuildType(), out_shape);
return ret;
}
MIND_API_OPERATOR_IMPL(COO2CSR, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(COO2CSR, prim::kPrimCOO2CSR, COO2CSRInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
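
For intuition, a worked example of what COO2CSRInfer above derives (illustrative values, not taken from the diff):

// row_indices : AbstractTensor with Int32 elements, shape (nnz,)
// height      : AbstractScalar holding 4
// => inferred output: AbstractTensor with Int32 elements and shape (height + 1,) = (5,)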


@@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_COO_TO_CSR
#define MINDSPORE_CORE_OPS_COO_TO_CSR
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCOO2CSR = "COO2CSR";
/// \brief Converts the row indices of a COOTensor to the indptr of a CSRTensor.
class MIND_API COO2CSR : public BaseOperator {
public:
MIND_API_BASE_MEMBER(COO2CSR);
/// \brief Constructor.
COO2CSR() : BaseOperator(kNameCOO2CSR) { InitIOName({"row_indices", "height"}, {"output"}); }
};
abstract::AbstractBasePtr COO2CSRInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_COO_TO_CSR


@@ -0,0 +1,68 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/csr_elementwise.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
using abstract::AbstractTensor;
using abstract::AbstractTuple;
AbstractBasePtr CSRElementWiseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
// Inputs: a sparse tensor and a dense tensor.
constexpr auto kCSRElementwiseInputsNum = 5;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, input_args, kCSRElementwiseInputsNum);
auto indptr = abstract::CheckArg<AbstractTensor>(op_name, input_args, 0);
auto indices = abstract::CheckArg<AbstractTensor>(op_name, input_args, 1);
auto values = abstract::CheckArg<AbstractTensor>(op_name, input_args, 2);
auto shape = abstract::CheckArg<AbstractTuple>(op_name, input_args, 3);
auto dense = abstract::CheckArg<AbstractTensor>(op_name, input_args, 4);
MS_EXCEPTION_IF_NULL(indptr);
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(values);
MS_EXCEPTION_IF_NULL(shape);
MS_EXCEPTION_IF_NULL(dense);
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
ShapeVector sparse_shape = ConvertToShapeVector(shape);
auto dense_shape = dense->shape()->shape();
CheckSparseShape(sparse_shape, dense_shape);
auto ret = values->Broaden();
// SetAttr
auto nnz_vec = indices->shape()->shape();
auto csr_avg_rows = nnz_vec[0] / dense_shape[0];
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kIsCSR, MakeValue(true));
return ret;
}
MIND_API_OPERATOR_IMPL(CSRMul, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(CSRMul, prim::kPrimCSRMul, CSRElementWiseInfer, nullptr, true);
MIND_API_OPERATOR_IMPL(CSRDiv, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(CSRDiv, prim::kPrimCSRDiv, CSRElementWiseInfer, nullptr, true);
} // namespace ops
} // namespace mindspore


@@ -0,0 +1,56 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CSR_ELEMENTWISE
#define MINDSPORE_CORE_OPS_CSR_ELEMENTWISE
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCSRMul = "CSRMul";
constexpr auto kNameCSRDiv = "CSRDiv";
/// \brief CSRTensor elementwise operation.
class MIND_API CSRMul : public BaseOperator {
public:
MIND_API_BASE_MEMBER(CSRMul);
/// \brief Constructor.
CSRMul() : BaseOperator(kNameCSRMul) {
InitIOName({"indptr", "indices", "values", "dense_shape", "dense_tensor"}, {"output"});
}
};
class MIND_API CSRDiv : public BaseOperator {
public:
MIND_API_BASE_MEMBER(CSRDiv);
/// \brief Constructor.
CSRDiv() : BaseOperator(kNameCSRDiv) {
InitIOName({"indptr", "indices", "values", "dense_shape", "dense_tensor"}, {"output"});
}
};
abstract::AbstractBasePtr CSRElementWiseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CSR_ELEMENTWISE


@@ -0,0 +1,70 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/csr_gather.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
using abstract::AbstractTensor;
using abstract::AbstractTuple;
AbstractBasePtr CSRGatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
// Inputs: the indptr and indices of a sparse csr tensor, a dense tensor, and the shape of the sparse tensor.
constexpr size_t csr_row_num = 2;
const std::string op_name = primitive->name();
abstract::CheckArgsSize(op_name, input_args, kSizeFour);
auto indptr = abstract::CheckArg<AbstractTensor>(op_name, input_args, kIndexZero);
auto indices = abstract::CheckArg<AbstractTensor>(op_name, input_args, kIndexOne);
auto dense = abstract::CheckArg<AbstractTensor>(op_name, input_args, kIndexTwo);
auto sparse_shape = abstract::CheckArg<AbstractTuple>(op_name, input_args, kIndexThree);
MS_EXCEPTION_IF_NULL(indptr);
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(dense);
MS_EXCEPTION_IF_NULL(sparse_shape);
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
auto shape_value = sparse_shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(shape_value);
auto nnz_vec = indices->shape()->shape();
int64_t csr_avg_rows = nnz_vec[0] / GetValue<int64_t>(shape_value->value()[0]);
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kIsCSR, MakeValue(true));
MS_EXCEPTION_IF_NULL(indices->shape());
ShapeVector out_shape = indices->shape()->shape();
MS_EXCEPTION_IF_NULL(dense->shape());
ShapeVector dense_shape = dense->shape()->shape();
for (size_t i = csr_row_num; i < dense_shape.size(); ++i) {
out_shape.push_back(dense_shape[i]);
}
MS_EXCEPTION_IF_NULL(dense->element());
auto ret = std::make_shared<AbstractTensor>(dense->element()->BuildType(), out_shape);
return ret;
}
MIND_API_OPERATOR_IMPL(CSRGather, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(CSRGather, prim::kPrimCSRGather, CSRGatherInfer, nullptr, true);
} // namespace ops
} // namespace mindspore


@@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CSR_GATHER
#define MINDSPORE_CORE_OPS_CSR_GATHER
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCSRGather = "CSRGather";
/// \brief Returns the values of a CSRTensor indexed from a dense tensor using indptr and indices.
class MIND_API CSRGather : public BaseOperator {
public:
MIND_API_BASE_MEMBER(CSRGather);
/// \brief Constructor.
CSRGather() : BaseOperator(kNameCSRGather) { InitIOName({"indptr", "indices", "dense", "dense_shape"}, {"output"}); }
};
abstract::AbstractBasePtr CSRGatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CSR_GATHER


@@ -0,0 +1,78 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/csr_mm.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
using abstract::AbstractTensor;
using abstract::AbstractTuple;
AbstractBasePtr CSRMMInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
// Inputs: a sparse tensor and a dense tensor.
constexpr auto kCSRMMInputsNum = 5;
constexpr auto kCSRMMShapeSize = 2;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, input_args, kCSRMMInputsNum);
auto indptr = abstract::CheckArg<AbstractTensor>(op_name, input_args, 0);
auto indices = abstract::CheckArg<AbstractTensor>(op_name, input_args, 1);
auto values = abstract::CheckArg<AbstractTensor>(op_name, input_args, 2);
auto shape = abstract::CheckArg<AbstractTuple>(op_name, input_args, 3);
auto dense = abstract::CheckArg<AbstractTensor>(op_name, input_args, 4);
MS_EXCEPTION_IF_NULL(indptr);
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(values);
MS_EXCEPTION_IF_NULL(shape);
MS_EXCEPTION_IF_NULL(dense);
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
ShapeVector sparse_shape = ConvertToShapeVector(shape);
auto dense_shape = dense->shape()->shape();
if (sparse_shape.size() != kCSRMMShapeSize || dense_shape.size() != kCSRMMShapeSize) {
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMMShapeSize << "-D inputs! "
<< "But csr tensor has " << sparse_shape.size() << " dimensions, "
<< "and dense tensor has " << dense_shape.size() << " dimension(s). ";
}
if (dense_shape[kIndexZero] != sparse_shape[kIndexOne]) {
MS_EXCEPTION(ValueError) << "The dense's shape[0] should be equal to csr tensor's shape[1]"
<< ", but dense's shape[0] is: " << dense_shape[kIndexZero]
<< " and csr tensor's shape[1] is " << sparse_shape[kIndexOne];
}
ShapeVector out_shape = {sparse_shape[kIndexZero], dense_shape[kIndexOne]};
auto ret = std::make_shared<AbstractTensor>(values->element()->BuildType(), out_shape);
// SetAttr
auto nnz_vec = indices->shape()->shape();
auto csr_avg_rows = nnz_vec[kIndexZero] / dense_shape[kIndexZero];
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kIsCSR, MakeValue(true));
return ret;
}
MIND_API_OPERATOR_IMPL(CSRMM, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(CSRMM, prim::kPrimCSRMM, CSRMMInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
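
For reference, the shape rule that CSRMMInfer enforces can be stated without the abstract-value machinery: a CSR operand of shape (m, k) multiplied by a dense matrix of shape (k, n) yields (m, n), and the kCSRAvgRows attribute records nnz / dense_shape[0]. The standalone sketch below is only an illustration of that rule under these assumptions; the helper name is hypothetical and is not part of this patch.

// Standalone sketch of the CSRMMInfer shape rule; InferCsrMmShape is a hypothetical helper.
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> InferCsrMmShape(const std::vector<int64_t> &sparse_shape,   // (m, k)
                                     const std::vector<int64_t> &dense_shape,    // (k, n)
                                     int64_t nnz, int64_t *csr_avg_rows) {
  if (sparse_shape.size() != 2 || dense_shape.size() != 2) {
    throw std::invalid_argument("CSRMM currently supports 2-D inputs only");
  }
  if (dense_shape[0] != sparse_shape[1]) {
    throw std::invalid_argument("dense.shape[0] must equal the CSR tensor's shape[1]");
  }
  *csr_avg_rows = nnz / dense_shape[0];      // mirrors the kCSRAvgRows attribute set above
  return {sparse_shape[0], dense_shape[1]};  // output shape is (m, n)
}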

View File

@ -0,0 +1,45 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CSR_MM
#define MINDSPORE_CORE_OPS_CSR_MM
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCSRMM = "CSRMM";
/// \brief Sparse matrix-matrix multiplication.
class MIND_API CSRMM : public BaseOperator {
public:
MIND_API_BASE_MEMBER(CSRMM);
/// \brief Constructor.
CSRMM() : BaseOperator(kNameCSRMM) {
InitIOName({"indptr", "indices", "values", "dense_shape", "dense_tensor"}, {"output"});
}
};
abstract::AbstractBasePtr CSRMMInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CSR_MM

View File

@ -0,0 +1,77 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/csr_mv.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
using abstract::AbstractTensor;
using abstract::AbstractTuple;
AbstractBasePtr CSRMVInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
constexpr auto kCSRMVInputsNum = 5;
constexpr auto kCSRMVShapeSize = 2;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, input_args, kCSRMVInputsNum);
auto indptr = abstract::CheckArg<AbstractTensor>(op_name, input_args, 0);
auto indices = abstract::CheckArg<AbstractTensor>(op_name, input_args, 1);
auto values = abstract::CheckArg<AbstractTensor>(op_name, input_args, 2);
auto shape = abstract::CheckArg<AbstractTuple>(op_name, input_args, 3);
auto dense = abstract::CheckArg<AbstractTensor>(op_name, input_args, 4);
MS_EXCEPTION_IF_NULL(indptr);
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(values);
MS_EXCEPTION_IF_NULL(shape);
MS_EXCEPTION_IF_NULL(dense);
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
ShapeVector sparse_shape = ConvertToShapeVector(shape);
ShapeVector dense_shape = dense->shape()->shape();
if (sparse_shape.size() != kCSRMVShapeSize || dense_shape.size() != kCSRMVShapeSize) {
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMVShapeSize << "-D inputs! "
<< "But csr tensor has " << sparse_shape.size() << " dimensions, "
<< "and dense tensor has " << dense_shape.size() << " dimension(s). ";
}
if (dense_shape[kIndexZero] != sparse_shape[kIndexOne] || dense_shape[kIndexOne] != 1) {
MS_EXCEPTION(ValueError) << "The dense_vector's shape should be (" << sparse_shape[kIndexOne] << ", 1)"
<< ", but its current shape is: "
<< "(" << dense_shape[kIndexZero] << ", " << dense_shape[kIndexOne] << ").";
}
ShapeVector out_shape = {sparse_shape[kIndexZero], dense_shape[kIndexOne]};
auto ret = std::make_shared<AbstractTensor>(values->element()->BuildType(), out_shape);
// SetAttr
auto nnz_vec = indices->shape()->shape();
auto csr_avg_rows = nnz_vec[kIndexZero] / dense_shape[kIndexZero];
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kIsCSR, MakeValue(true));
return ret;
}
MIND_API_OPERATOR_IMPL(CSRMV, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(CSRMV, prim::kPrimCSRMV, CSRMVInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
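
CSRMVInfer applies the analogous rule for a matrix-vector product: the dense operand must be a column vector of shape (csr.shape[1], 1), and the output has shape (csr.shape[0], 1). Below is a hypothetical standalone sketch covering only the shape check (the attribute bookkeeping is omitted); the helper name is not part of this patch.

// Standalone sketch of the CSRMVInfer shape rule; InferCsrMvShape is a hypothetical helper.
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> InferCsrMvShape(const std::vector<int64_t> &sparse_shape,    // (m, k)
                                     const std::vector<int64_t> &dense_shape) {   // must be (k, 1)
  if (sparse_shape.size() != 2 || dense_shape.size() != 2) {
    throw std::invalid_argument("CSRMV currently supports 2-D inputs only");
  }
  if (dense_shape[0] != sparse_shape[1] || dense_shape[1] != 1) {
    throw std::invalid_argument("the dense vector must have shape (csr.shape[1], 1)");
  }
  return {sparse_shape[0], dense_shape[1]};  // output shape is (m, 1)
}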

View File

@ -0,0 +1,45 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CSR_MV
#define MINDSPORE_CORE_OPS_CSR_MV
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCSRMV = "CSRMV";
/// \brief Sparse matrix-vector multiplication.
class MIND_API CSRMV : public BaseOperator {
public:
MIND_API_BASE_MEMBER(CSRMV);
/// \brief Constructor.
CSRMV() : BaseOperator(kNameCSRMV) {
InitIOName({"indptr", "indices", "values", "dense_shape", "dense_tensor"}, {"output"});
}
};
abstract::AbstractBasePtr CSRMVInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CSR_MV

View File

@ -0,0 +1,85 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/csr_reducesum.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
using abstract::AbstractScalar;
using abstract::AbstractTensor;
using abstract::AbstractTuple;
AbstractBasePtr CSRReduceSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
// Inputs: a sparse tensor and an axis.
constexpr auto kCSRReduceSumInputsNum = 5;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, input_args, kCSRReduceSumInputsNum);
auto indptr = abstract::CheckArg<AbstractTensor>(op_name, input_args, 0);
auto indices = abstract::CheckArg<AbstractTensor>(op_name, input_args, 1);
auto values = abstract::CheckArg<AbstractTensor>(op_name, input_args, 2);
auto shape = abstract::CheckArg<AbstractTuple>(op_name, input_args, 3);
auto axis = abstract::CheckArg<AbstractScalar>(op_name, input_args, 4);
MS_EXCEPTION_IF_NULL(indptr);
MS_EXCEPTION_IF_NULL(indices);
MS_EXCEPTION_IF_NULL(values);
MS_EXCEPTION_IF_NULL(shape);
MS_EXCEPTION_IF_NULL(axis);
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
ShapeVector sparse_shape = ConvertToShapeVector(shape);
ShapeVector out_shape = sparse_shape;
MS_EXCEPTION_IF_NULL(axis->BuildValue());
if (axis->BuildValue()->isa<Int32Imm>() || axis->BuildValue()->isa<Int64Imm>()) {
int64_t axis_value = GetValue<int64_t>(axis->BuildValue());
int64_t dim = static_cast<int64_t>(sparse_shape.size());
if (axis_value != 1 && axis_value != 1 - dim) {
MS_EXCEPTION(ValueError) << "For CSRReduceSum, `axis` should be 1 or 1-dim. But got `axis`: " << axis_value
<< "and `1- dim`: " << 1 - dim << ".";
}
if (axis_value < 0) {
axis_value += dim;
}
out_shape[LongToSize(axis_value)] = 1;
primitive->set_attr(kCSRAxis, MakeValue(axis_value));
} else {
MS_EXCEPTION(TypeError) << "For CSRReduceSum, `axis` should be int32 or int64, but got "
<< axis->BuildType()->ToString() << ".";
}
MS_EXCEPTION_IF_NULL(values->element());
auto ret = std::make_shared<AbstractTensor>(values->element()->BuildType(), out_shape);
// SetAttr
auto nnz_vec = indices->shape()->shape();
auto csr_avg_rows = nnz_vec[0] / sparse_shape[0];
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kIsCSR, MakeValue(true));
return ret;
}
MIND_API_OPERATOR_IMPL(CSRReduceSum, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(CSRReduceSum, prim::kPrimCSRReduceSum, CSRReduceSumInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
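
The axis handling above accepts only axis == 1 or axis == 1 - dim, wraps a negative axis, and collapses the reduced dimension to 1. The standalone sketch below illustrates just that normalization; the helper name is hypothetical and not part of this patch.

// Standalone sketch of the axis normalization in CSRReduceSumInfer; ReduceCsrShape is hypothetical.
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> ReduceCsrShape(std::vector<int64_t> sparse_shape, int64_t axis) {
  const int64_t dim = static_cast<int64_t>(sparse_shape.size());
  if (axis != 1 && axis != 1 - dim) {
    throw std::invalid_argument("axis must be 1 or 1 - dim");
  }
  if (axis < 0) {
    axis += dim;  // for a 2-D CSR tensor, axis = -1 becomes 1
  }
  sparse_shape[static_cast<std::size_t>(axis)] = 1;  // the reduced dimension collapses to 1
  return sparse_shape;                               // e.g. (m, k) -> (m, 1)
}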

View File

@ -0,0 +1,45 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CSR_REDUCESUM
#define MINDSPORE_CORE_OPS_CSR_REDUCESUM
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCSRReduceSum = "CSRReduceSum";
/// \brief CSRTensor reducesum.
class MIND_API CSRReduceSum : public BaseOperator {
public:
MIND_API_BASE_MEMBER(CSRReduceSum);
/// \brief Constructor.
CSRReduceSum() : BaseOperator(kNameCSRReduceSum) {
InitIOName({"indptr", "indices", "values", "dense_shape", "axis"}, {"output"});
}
};
abstract::AbstractBasePtr CSRReduceSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CSR_REDUCESUM

View File

@ -0,0 +1,41 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/csr_tensor_get_indices.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
abstract::AbstractBasePtr CSRTensorGetIndicesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
auto csr_tensor = InferSparseAttr<abstract::AbstractCSRTensor>(primitive, args_spec_list);
MS_EXCEPTION_IF_NULL(csr_tensor->indices());
return csr_tensor->indices();
}
MIND_API_OPERATOR_IMPL(CSRTensorGetIndices, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(CSRTensorGetIndices, prim::kPrimCSRTensorGetIndices, CSRTensorGetIndicesInfer, nullptr,
true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDICES_H_
#define MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDICES_H_
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCSRTensorGetIndices = "CSRTensorGetIndices";
/// \brief CSRTensorGetIndices op is used to get indices in CSRTensor.
class MIND_API CSRTensorGetIndices : public BaseOperator {
public:
MIND_API_BASE_MEMBER(CSRTensorGetIndices);
/// \brief Constructor.
CSRTensorGetIndices() : BaseOperator(kNameCSRTensorGetIndices) {}
};
abstract::AbstractBasePtr CSRTensorGetIndicesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDICES_H_

View File

@ -0,0 +1,40 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/csr_tensor_get_indptr.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
abstract::AbstractBasePtr CSRTensorGetIndptrInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
auto csr_tensor = InferSparseAttr<abstract::AbstractCSRTensor>(primitive, args_spec_list);
MS_EXCEPTION_IF_NULL(csr_tensor->indptr());
return csr_tensor->indptr();
}
MIND_API_OPERATOR_IMPL(CSRTensorGetIndptr, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(CSRTensorGetIndptr, prim::kPrimCSRTensorGetIndptr, CSRTensorGetIndptrInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDPTR_H_
#define MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDPTR_H_
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCSRTensorGetIndptr = "CSRTensorGetIndptr";
/// \brief CSRTensorGetIndptr op is used to get indptr in CSRTensor.
class MIND_API CSRTensorGetIndptr : public BaseOperator {
public:
MIND_API_BASE_MEMBER(CSRTensorGetIndptr);
/// \brief Constructor.
CSRTensorGetIndptr() : BaseOperator(kNameCSRTensorGetIndptr) {}
};
abstract::AbstractBasePtr CSRTensorGetIndptrInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDPTR_H_

View File

@ -0,0 +1,41 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/csr_tensor_get_shape.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
abstract::AbstractBasePtr CSRTensorGetShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
auto csr_tensor = InferSparseAttr<abstract::AbstractCSRTensor>(primitive, args_spec_list);
MS_EXCEPTION_IF_NULL(csr_tensor->shape());
return csr_tensor->shape();
}
MIND_API_OPERATOR_IMPL(CSRTensorGetShape, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(CSRTensorGetShape, prim::kPrimCSRTensorGetDenseShape, CSRTensorGetShapeInfer, nullptr,
true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CSRTENSOR_GET_SHAPE_H_
#define MINDSPORE_CORE_OPS_CSRTENSOR_GET_SHAPE_H_
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCSRTensorGetShape = "CSRTensorGetShape";
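/// \brief CSRTensorGetShape op is used to get dense_shape in CSRTensor.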
class MIND_API CSRTensorGetShape : public BaseOperator {
public:
MIND_API_BASE_MEMBER(CSRTensorGetShape);
/// \brief Constructor.
CSRTensorGetShape() : BaseOperator(kNameCSRTensorGetShape) {}
};
abstract::AbstractBasePtr CSRTensorGetShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CSRTENSOR_GET_SHAPE_H_

View File

@ -0,0 +1,40 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/csr_tensor_get_values.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
abstract::AbstractBasePtr CSRTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
auto csr_tensor = InferSparseAttr<abstract::AbstractCSRTensor>(primitive, args_spec_list);
MS_EXCEPTION_IF_NULL(csr_tensor->values());
return csr_tensor->values();
}
MIND_API_OPERATOR_IMPL(CSRTensorGetValues, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(CSRTensorGetValues, prim::kPrimCSRTensorGetValues, CSRTensorGetValuesInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CSRTENSOR_GET_VALUES_H_
#define MINDSPORE_CORE_OPS_CSRTENSOR_GET_VALUES_H_
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCSRTensorGetValues = "CSRTensorGetValues";
/// \brief CSRTensorGetValues op is used to get values in CSRTensor.
class MIND_API CSRTensorGetValues : public BaseOperator {
public:
MIND_API_BASE_MEMBER(CSRTensorGetValues);
/// \brief Constructor.
CSRTensorGetValues() : BaseOperator(kNameCSRTensorGetValues) {}
};
abstract::AbstractBasePtr CSRTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CSRTENSOR_GET_VALUES_H_

View File

@ -0,0 +1,67 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/csr_to_coo.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
using abstract::AbstractScalar;
using abstract::AbstractTensor;
using abstract::AbstractTuple;
AbstractBasePtr CSR2COOInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
// Inputs: the indptr of a sparse csr tensor, and the number of non-zero elements.
constexpr auto kCSRArgsSize = 2;
const std::string op_name = primitive->name();
CheckArgsSize(op_name, input_args, kCSRArgsSize);
auto indptr = abstract::CheckArg<AbstractTensor>(op_name, input_args, 0);
auto nnz = abstract::CheckArg<AbstractScalar>(op_name, input_args, 1);
MS_EXCEPTION_IF_NULL(indptr);
MS_EXCEPTION_IF_NULL(nnz);
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
MS_EXCEPTION_IF_NULL(nnz->BuildValue());
ShapeVector out_shape;
if (nnz->BuildValue()->isa<Int32Imm>() || nnz->BuildValue()->isa<Int64Imm>()) {
int64_t nnz_value = GetValue<int64_t>(nnz->BuildValue());
out_shape.push_back(nnz_value);
} else {
MS_EXCEPTION(ValueError) << "Currently, only support Integer nnz.";
}
MS_EXCEPTION_IF_NULL(indptr->shape());
auto num_rows = indptr->shape()->shape()[0] - 1;
auto csr_avg_rows = GetValue<int64_t>(nnz->BuildValue()) / num_rows;
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
primitive->set_attr(kIsCSR, MakeValue(true));
MS_EXCEPTION_IF_NULL(indptr->element());
auto ret = std::make_shared<AbstractTensor>(indptr->element()->BuildType(), out_shape);
return ret;
}
MIND_API_OPERATOR_IMPL(CSR2COO, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(CSR2COO, prim::kPrimCSR2COO, CSR2COOInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
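
CSR2COOInfer only needs the length of indptr and the nnz value: the output row-index tensor is 1-D with nnz elements, and kCSRAvgRows records nnz divided by the number of rows (indptr length minus one). A hypothetical standalone sketch of that rule, not part of this patch:

// Standalone sketch of the CSR2COOInfer output shape; InferCsr2CooShape is a hypothetical helper.
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> InferCsr2CooShape(int64_t indptr_len, int64_t nnz, int64_t *csr_avg_rows) {
  if (indptr_len < 2) {
    throw std::invalid_argument("indptr must contain at least num_rows + 1 = 2 entries");
  }
  const int64_t num_rows = indptr_len - 1;  // indptr stores one offset per row plus a trailing total
  *csr_avg_rows = nnz / num_rows;           // mirrors the kCSRAvgRows attribute set above
  return {nnz};                             // the COO row-index output is a 1-D tensor of length nnz
}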

View File

@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_CSR_TO_COO
#define MINDSPORE_CORE_OPS_CSR_TO_COO
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameCSR2COO = "CSR2COO";
/// \brief Converts the indptr of a CSRTensor to the row indices of a COOTensor.
class MIND_API CSR2COO : public BaseOperator {
public:
MIND_API_BASE_MEMBER(CSR2COO);
/// \brief Constructor.
CSR2COO() : BaseOperator(kNameCSR2COO) { InitIOName({"indptr", "nnz"}, {"output"}); }
};
abstract::AbstractBasePtr CSR2COOInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_CSR_TO_COO

View File

@ -0,0 +1,104 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <algorithm>
#include <memory>
#include "ops/make_cootensor.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
using abstract::AbstractTensor;
using abstract::AbstractTuple;
AbstractBasePtr MakeCOOTensorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
// Inputs: two tensors and a tuple.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kSizeThree);
auto indices = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexZero);
auto values = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexOne);
auto dense_shape = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, kIndexTwo);
auto indices_dtype = indices->element()->BuildType();
CheckSparseIndicesDtype(indices_dtype, "Indices");
auto indices_shp = indices->shape()->shape();
CheckSparseShape(indices_shp.size(), kSizeTwo, "Indices");
auto values_shp = values->shape()->shape();
CheckSparseShape(values_shp.size(), kSizeOne, "Values");
if (indices_shp[kIndexZero] != values_shp[kIndexZero]) {
MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexZero << "]` must be equal to `values.shape["
<< kIndexZero << "]`, but got `indices.shape[" << kIndexZero
<< "]`: " << indices_shp[kIndexZero] << " and `values.shape[" << kIndexZero
<< "]`: " << values_shp[kIndexZero];
}
constexpr int64_t kDimTwo = 2;
if (indices_shp[kIndexOne] != kDimTwo) {
MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexOne << "]` must be " << kDimTwo << ",but got "
<< indices_shp[kIndexOne];
}
for (const auto &elem_type : dense_shape->ElementsType()) {
if (!elem_type->isa<Int>()) {
MS_EXCEPTION(TypeError) << "For COOTensor, the element type of `shape` must be Int, but got "
<< elem_type->ToString();
}
}
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(dense_shape_value);
auto shp = dense_shape_value->value();
auto min_elem = *std::min_element(std::begin(shp), std::end(shp));
if (min_elem <= 0) {
MS_EXCEPTION(ValueError) << "For COOTensor, the element of `shape` must be positive integer. But got " << min_elem
<< "int it";
}
ShapeVector dense_shape_vec;
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
[](const ValuePtr &e) -> int64_t {
auto elem = GetValue<int64_t>(e);
return elem;
});
if (LongToSize(indices_shp[kIndexOne]) != dense_shape_vec.size()) {
MS_EXCEPTION(TypeError) << "For COOTensor, `indices.shape[" << indices_shp << "]` must be equal to the second "
<< "dimension of `indices`: " << dense_shape_vec.size() << " but got "
<< indices_shp[kIndexOne];
}
for (auto dense_shape_elem : dense_shape_vec) {
if (dense_shape_elem <= 0) {
MS_EXCEPTION(TypeError) << "For COOTensor, the element of `shape` must be positive, but got "
<< dense_shape_value->ToString();
}
}
AbstractBasePtrList element_list{indices, values, dense_shape};
return std::make_shared<abstract::AbstractCOOTensor>(element_list);
}
MIND_API_OPERATOR_IMPL(MakeCOOTensor, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(MakeCOOTensor, prim::kPrimMakeCOOTensor, MakeCOOTensorInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
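
The checks in MakeCOOTensorInfer reduce to a few shape constraints: indices must be (nnz, 2) with an integer dtype, values must be (nnz,), and dense_shape must have one positive entry per indices column. The standalone sketch below covers only the shape part (dtype checks omitted); the helper name is hypothetical and not part of this patch.

// Standalone sketch of the shape checks in MakeCOOTensorInfer; ValidateCooParts is hypothetical.
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

void ValidateCooParts(const std::vector<int64_t> &indices_shape,  // expected (nnz, 2)
                      const std::vector<int64_t> &values_shape,   // expected (nnz)
                      const std::vector<int64_t> &dense_shape) {  // expected two positive entries
  if (indices_shape.size() != 2 || values_shape.size() != 1) {
    throw std::invalid_argument("indices must be 2-D and values must be 1-D");
  }
  if (indices_shape[0] != values_shape[0]) {
    throw std::invalid_argument("indices.shape[0] must equal values.shape[0]");
  }
  if (indices_shape[1] != 2 || dense_shape.size() != static_cast<std::size_t>(indices_shape[1])) {
    throw std::invalid_argument("indices.shape[1] must be 2 and match the length of dense_shape");
  }
  for (int64_t d : dense_shape) {
    if (d <= 0) {
      throw std::invalid_argument("every element of dense_shape must be a positive integer");
    }
  }
}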

View File

@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MAKE_COOTENSOR_H_
#define MINDSPORE_CORE_OPS_MAKE_COOTENSOR_H_
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMakeCOOTensor = "MakeCOOTensor";
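/// \brief MakeCOOTensor op is used to construct COOTensor.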
class MIND_API MakeCOOTensor : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MakeCOOTensor);
/// \brief Constructor.
MakeCOOTensor() : BaseOperator(kNameMakeCOOTensor) {}
};
abstract::AbstractBasePtr MakeCOOTensorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MAKE_COOTENSOR_H_

View File

@ -0,0 +1,104 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <algorithm>
#include <memory>
#include "ops/make_csrtensor.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
using abstract::AbstractTensor;
using abstract::AbstractTuple;
AbstractBasePtr MakeCSRTensorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
// Inputs: three tensors and a tuple.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kSizeFour);
auto indptr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexZero);
auto indices = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexOne);
auto values = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexTwo);
auto shape = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, kIndexThree);
auto indptr_dtype = indptr->element()->BuildType();
auto indices_dtype = indices->element()->BuildType();
CheckSparseIndicesDtype(indptr_dtype, "indptr");
CheckSparseIndicesDtype(indices_dtype, "indices");
auto indptr_shp = indptr->shape()->shape();
CheckSparseShape(indptr_shp.size(), kSizeOne, "Indptr");
auto indices_shp = indices->shape()->shape();
CheckSparseShape(indices_shp.size(), kSizeOne, "Indices");
auto values_shp = values->shape()->shape();
if (indices_shp[kIndexZero] != values_shp[kIndexZero]) {
MS_EXCEPTION(ValueError) << "Indices and values must have same size, but got: values length: "
<< values_shp[kIndexZero] << ", indices length " << indices_shp[kIndexZero];
}
auto shape_value = shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(shape_value);
auto shp = shape_value->value();
ShapeVector shape_vec;
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(shape_vec), [](const ValuePtr &e) -> int64_t {
auto elem = GetValue<int64_t>(e);
return elem;
});
if (values_shp.size() + 1 != shape_vec.size()) {
MS_EXCEPTION(ValueError) << "Values' dimension should equal to csr_tensor's dimension - 1, but got"
<< "Values' dimension: " << values_shp.size()
<< ", csr_tensor's dimension: " << shape_vec.size() << ".";
}
if (shape_vec[kIndexZero] + 1 != indptr_shp[kIndexZero]) {
MS_EXCEPTION(ValueError) << "Indptr must have length (1 + shape[0]), but got: " << indptr_shp[kIndexZero];
}
size_t shape_size = 1;
auto shape_types = shape->ElementsType();
for (size_t i = 0; i < shape_vec.size(); ++i) {
if (shape_vec[i] <= 0) {
MS_EXCEPTION(ValueError) << "The element of shape must be positive, but got " << shape_value->ToString();
}
if ((i > 1) && (shape_vec[i] != values_shp[i - 1])) {
MS_EXCEPTION(ValueError) << "csr_tensor's shape should match with values' shape.";
}
if (!shape_types[i]->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The element type of shape must be Int, but got " << shape_types[i]->ToString();
}
shape_size *= LongToSize(shape_vec[i]);
}
if (static_cast<int64_t>(shape_size) < values_shp[kIndexZero]) {
MS_EXCEPTION(ValueError) << "Shape total size: " << shape_size << " is too small to hold " << values_shp[kIndexZero]
<< " non-zero values.";
}
std::vector<abstract::AbstractBasePtr> element_list{indptr, indices, values, shape};
return std::make_shared<abstract::AbstractCSRTensor>(element_list);
}
MIND_API_OPERATOR_IMPL(MakeCSRTensor, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(MakeCSRTensor, prim::kPrimMakeCSRTensor, MakeCSRTensorInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
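
Likewise, MakeCSRTensorInfer boils down to the shape constraints summarized below. The sketch is a hypothetical standalone helper covering only the shape checks (dtype checks omitted) and is not part of this patch.

// Standalone sketch of the shape checks in MakeCSRTensorInfer; ValidateCsrParts is hypothetical.
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

void ValidateCsrParts(const std::vector<int64_t> &indptr_shape,   // expected (shape[0] + 1)
                      const std::vector<int64_t> &indices_shape,  // expected (nnz)
                      const std::vector<int64_t> &values_shape,   // expected (nnz, ...)
                      const std::vector<int64_t> &shape) {
  if (indptr_shape.size() != 1 || indices_shape.size() != 1) {
    throw std::invalid_argument("indptr and indices must be 1-D");
  }
  if (indices_shape[0] != values_shape[0]) {
    throw std::invalid_argument("indices and values must have the same length");
  }
  if (values_shape.size() + 1 != shape.size()) {
    throw std::invalid_argument("values must have one dimension fewer than the CSR tensor");
  }
  if (shape[0] + 1 != indptr_shape[0]) {
    throw std::invalid_argument("indptr must have length shape[0] + 1");
  }
  int64_t total = 1;
  for (std::size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] <= 0) {
      throw std::invalid_argument("every element of shape must be positive");
    }
    if (i > 1 && shape[i] != values_shape[i - 1]) {
      throw std::invalid_argument("trailing dimensions of shape must match values' shape");
    }
    total *= shape[i];
  }
  if (total < values_shape[0]) {
    throw std::invalid_argument("shape is too small to hold all non-zero values");
  }
}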

View File

@ -0,0 +1,40 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_MAKE_CSRTENSOR_H_
#define MINDSPORE_CORE_OPS_MAKE_CSRTENSOR_H_
#include <vector>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameMakeCSRTensor = "MakeCSRTensor";
/// \brief MakeCSRTensor op is used to construct CSRTensor.
class MIND_API MakeCSRTensor : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MakeCSRTensor);
/// \brief Constructor.
MakeCSRTensor() : BaseOperator(kNameMakeCSRTensor) {}
};
abstract::AbstractBasePtr MakeCSRTensorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_MAKE_CSRTENSOR_H_

View File

@ -466,5 +466,99 @@ ValuePtr InferMakeShapeTensorValue(const PrimitivePtr &prim, const AbstractBaseP
ValuePtr InferComputeShapeTensorValue(const PrimitivePtr &prim, const AbstractBasePtrList &args) {
return EvalShapeTensorValue(prim, args, true);
}
void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp) {
constexpr auto kCSRMulBatchPos = 2;
int dlen = SizeToInt(sparse_shp.size()) - SizeToInt(dense_shp.size());
if (dlen < 0) {
MS_EXCEPTION(ValueError) << "Currently, only support dense tensor broadcast to sparse tensor, "
<< "but sparse tensor has " << sparse_shp.size() << " dimensions, "
<< "and dense tensor has " << dense_shp.size() << " dimensions. ";
}
for (int i = 0; i < dlen; i++) {
(void)dense_shp.insert(dense_shp.begin(), 1);
}
if (sparse_shp.size() != dense_shp.size()) {
MS_LOG(EXCEPTION) << "Failure: sparse_shp.size() != dense_shp.size().";
}
if (sparse_shp.size() < 1) {
MS_LOG(EXCEPTION) << "Failure: dense tensor and sparse tensor shapes cannot be zero.";
}
for (size_t i = 0; i < sparse_shp.size(); i++) {
auto s = sparse_shp[i];
auto d = dense_shp[i];
if (i < kCSRMulBatchPos) {
if (d != s && d != 1) {
MS_EXCEPTION(ValueError) << "Dense shape cannot broadcast to sparse shape.";
}
} else {
if (d != s) {
MS_EXCEPTION(ValueError) << "Currently, sparse shape and dense shape must equal in feature dimensions.";
}
}
}
}
void CheckSparseShape(const size_t shape_size, const size_t expected_dim, const std::string &arg_name) {
if (shape_size != expected_dim) {
MS_EXCEPTION(ValueError) << arg_name << " must be a " << expected_dim << "-dimensional tensor, but got a "
<< shape_size << "-dimensional tensor.";
}
}
void CheckSparseIndicesDtype(const TypePtr data_type, const std::string &arg_name) {
if (!(data_type->equal(kInt16) || data_type->equal(kInt32) || data_type->equal(kInt64))) {
MS_EXCEPTION(TypeError) << "The dtype of " << arg_name << " must be Int16 or Int32 or Int64, but got "
<< data_type->ToString() << ".";
}
}
void CheckSparseIndicesDtypeInt32(const TypePtr data_type, const std::string &arg_name) {
if (!data_type->equal(kInt32)) {
MS_EXCEPTION(TypeError) << "The dtype of " << arg_name << " only support Int32 for now, but got "
<< data_type->ToString() << ".";
}
}
ShapeVector ConvertToShapeVector(const abstract::AbstractTuplePtr &shape) {
auto shape_value = shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(shape_value);
ShapeVector shape_vec;
(void)std::transform(std::begin(shape_value->value()), std::end(shape_value->value()), std::back_inserter(shape_vec),
[](const ValuePtr &e) -> int64_t {
auto elem = GetValue<int64_t>(e);
return elem;
});
return shape_vec;
}
template <typename T>
std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(primitive);
constexpr size_t kSizeExpect = 1;
if (args_spec_list.size() != kSizeExpect) {
MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', the number of input should be " << kSizeExpect
<< ", but got " << args_spec_list.size() << ".";
}
constexpr size_t kIndex = 0;
auto abs = args_spec_list[kIndex];
MS_EXCEPTION_IF_NULL(abs);
auto abs_sparse = dyn_cast<T>(abs);
if (abs_sparse != nullptr) {
return abs_sparse;
}
// To avoid AbstractSparseTensors being generalized to AbstractTuple.
auto abs_tuple = dyn_cast<abstract::AbstractTuple>(abs);
if (abs_tuple != nullptr) {
return std::make_shared<T>(abs_tuple->elements());
}
MS_EXCEPTION(TypeError) << "For \'" << primitive->name() << "\', input[" << kIndex
<< "] should be AbstractSparseTensor or AbstractTuple, but got "
<< abs->BuildType()->ToString() << ".";
}
template std::shared_ptr<abstract::AbstractCSRTensor> InferSparseAttr(const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template std::shared_ptr<abstract::AbstractCOOTensor> InferSparseAttr(const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
} // namespace ops
} // namespace mindspore
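
The dense-to-sparse broadcast rule added to op_utils.cc above can be summarized as follows: the dense shape is left-padded with 1s to the sparse rank; in the first two (batch) positions a dense dimension may be 1 or equal to the sparse dimension; all remaining (feature) dimensions must match exactly. The standalone sketch below is a hypothetical helper illustrating that rule, not part of this patch.

// Standalone sketch of the broadcast rule enforced by CheckSparseShape(ShapeVector, ShapeVector);
// CanBroadcastToSparse is a hypothetical helper.
#include <cstddef>
#include <cstdint>
#include <vector>

bool CanBroadcastToSparse(const std::vector<int64_t> &sparse_shp, std::vector<int64_t> dense_shp) {
  constexpr std::size_t kBatchPos = 2;  // the first two positions are treated as batch dimensions
  if (dense_shp.size() > sparse_shp.size() || sparse_shp.empty()) {
    return false;  // only dense -> sparse broadcasting is supported
  }
  // Left-pad the dense shape with 1s so both shapes have the same rank.
  dense_shp.insert(dense_shp.begin(), sparse_shp.size() - dense_shp.size(), 1);
  for (std::size_t i = 0; i < sparse_shp.size(); ++i) {
    if (i < kBatchPos) {
      if (dense_shp[i] != sparse_shp[i] && dense_shp[i] != 1) {
        return false;  // batch dimensions may broadcast from 1
      }
    } else if (dense_shp[i] != sparse_shp[i]) {
      return false;  // feature dimensions must match exactly
    }
  }
  return true;
}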

View File

@ -80,5 +80,23 @@ ValuePtr InferMakeShapeTensorValue(const PrimitivePtr &prim, const AbstractBaseP
// Infer shape value of compute-shape op that could change the dim value, e.g. Mul, Add, Sub
// Do not support op with multiple outputs for now
ValuePtr InferComputeShapeTensorValue(const PrimitivePtr &prim, const AbstractBasePtrList &args);
void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp);
void CheckSparseShape(const size_t shape_size, const size_t expected_dim, const std::string &arg_name);
void CheckSparseIndicesDtype(const TypePtr data_type, const std::string &arg_name);
void CheckSparseIndicesDtypeInt32(const TypePtr data_type, const std::string &arg_name);
ShapeVector ConvertToShapeVector(const abstract::AbstractTuplePtr &shape);
template <typename T>
std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list);
constexpr auto kCSRAvgRows = "csr_avg_rows";
constexpr auto kIsCSR = "is_csr";
constexpr auto kCSRDenseShape = "dense_shape";
constexpr auto kCSRAxis = "axis";
} // namespace mindspore::ops
#endif // MINDSPORE_CORE_OPS_OP_UTILS_H

View File

@ -46,21 +46,6 @@ constexpr size_t kAlphaIndex = 10;
constexpr size_t kBetaIndex = 11;
constexpr int64_t kDefaultRank = 2;
constexpr int64_t kBatchedRank = 3;
inline void CheckSparseShape(const size_t sparse_shape_size, const size_t expected_dim, const std::string &arg_name) {
if (sparse_shape_size != expected_dim) {
MS_EXCEPTION(mindspore::ValueError) << arg_name << " must be a " << expected_dim
<< "-dimensional tensor, but got a " << sparse_shape_size
<< "-dimensional tensor.";
}
}
inline void CheckSparseIndicesDtype(const mindspore::TypePtr dtype, const std::string &arg_name) {
if (!(dtype->equal(mindspore::kInt16) || dtype->equal(mindspore::kInt32) || dtype->equal(mindspore::kInt64))) {
MS_EXCEPTION(mindspore::TypeError) << "The dtype of " << arg_name << " must be Int16 or Int32 or Int64, but got "
<< dtype->ToString() << ".";
}
}
} // namespace
void SparseMatrixAdd::set_dense_shape(const std::vector<int64_t> &shape) {
(void)this->AddAttr(kDenseShape, api::MakeValue(shape));