!39833 move all sparse infer functions to core/ops/
Merge pull request !39833 from 杨林枫/prim_infer_refactor
This commit is contained in:
commit
38065a1a9c
|
@ -129,30 +129,6 @@ AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitiveP
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
template <typename T>
|
||||
std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeCOOTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCOOTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCOOTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCOOTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRElementWise(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRMM(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
@ -163,18 +139,6 @@ AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRTensorGetIndptr(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplCSRTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -34,66 +34,10 @@
|
|||
|
||||
namespace {
|
||||
constexpr auto kRankSize = "rank_size";
|
||||
inline void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp) {
|
||||
constexpr auto kCSRMulBatchPos = 2;
|
||||
int dlen = mindspore::SizeToInt(sparse_shp.size()) - mindspore::SizeToInt(dense_shp.size());
|
||||
if (dlen < 0) {
|
||||
MS_EXCEPTION(mindspore::ValueError) << "Currently, only support dense tensor broadcast to sparse tensor, "
|
||||
<< "but sparse tensor has " << sparse_shp.size() << " dimensions, "
|
||||
<< "and dense tensor has " << dense_shp.size() << " dimensions, ";
|
||||
}
|
||||
for (int i = 0; i < dlen; i++) {
|
||||
(void)dense_shp.insert(dense_shp.begin(), 1);
|
||||
}
|
||||
if (sparse_shp.size() != dense_shp.size()) {
|
||||
MS_LOG(EXCEPTION) << "Failure: sparse_shp.size() != dense_shp.size().";
|
||||
}
|
||||
if (sparse_shp.size() < 1) {
|
||||
MS_LOG(EXCEPTION) << "Failure: dense tensor and sparse tensor shapes cannot be zero.";
|
||||
}
|
||||
for (size_t i = 0; i < sparse_shp.size(); i++) {
|
||||
auto s = sparse_shp[i];
|
||||
auto d = dense_shp[i];
|
||||
if (i < kCSRMulBatchPos) {
|
||||
if (d != s && d != 1) {
|
||||
MS_EXCEPTION(mindspore::ValueError) << "Dense shape cannot broadcast to sparse shape.";
|
||||
}
|
||||
} else {
|
||||
if (d != s) {
|
||||
MS_EXCEPTION(mindspore::ValueError)
|
||||
<< "Currently, sparse shape and dense shape must equal in feature dimensions.";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
inline void CheckSparseShape(const size_t shape_size, const size_t expected_dim, const std::string &arg_name) {
|
||||
if (shape_size != expected_dim) {
|
||||
MS_EXCEPTION(mindspore::ValueError) << arg_name << " must be a " << expected_dim
|
||||
<< "-dimensional tensor, but got a " << shape_size << "-dimensional tensor.";
|
||||
}
|
||||
}
|
||||
inline void CheckSparseIndicesDtype(const mindspore::TypePtr data_type, const std::string &arg_name) {
|
||||
if (!(data_type->equal(mindspore::kInt16) || data_type->equal(mindspore::kInt32) ||
|
||||
data_type->equal(mindspore::kInt64))) {
|
||||
MS_EXCEPTION(mindspore::TypeError) << "The dtype of " << arg_name << " must be Int16 or Int32 or Int64, but got "
|
||||
<< data_type->ToString() << ".";
|
||||
}
|
||||
}
|
||||
inline void CheckSparseIndicesDtypeInt32(const mindspore::TypePtr data_type, const std::string &arg_name) {
|
||||
if (!data_type->equal(mindspore::kInt32)) {
|
||||
MS_EXCEPTION(mindspore::TypeError) << "The dtype of " << arg_name << " only support Int32 for now, but got "
|
||||
<< data_type->ToString() << ".";
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
constexpr auto kCSRDenseShape = "dense_shape";
|
||||
constexpr auto kCSRAxis = "axis";
|
||||
constexpr auto kCSRAvgRows = "csr_avg_rows";
|
||||
constexpr auto kIsCSR = "is_csr";
|
||||
|
||||
AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// An object of a subclass of AbstractBase
|
||||
|
@ -343,433 +287,6 @@ AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const Primitive
|
|||
return args_spec_list[0];
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeCOOTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, kSizeThree);
|
||||
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexZero);
|
||||
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexOne);
|
||||
auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, kIndexTwo);
|
||||
|
||||
auto indices_dtype = indices->element()->BuildType();
|
||||
CheckSparseIndicesDtype(indices_dtype, "Indices");
|
||||
|
||||
auto indices_shp = indices->shape()->shape();
|
||||
CheckSparseShape(indices_shp.size(), kSizeTwo, "Indices");
|
||||
|
||||
auto values_shp = values->shape()->shape();
|
||||
CheckSparseShape(values_shp.size(), kSizeOne, "Values");
|
||||
|
||||
if (indices_shp[kIndexZero] != values_shp[kIndexZero]) {
|
||||
MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexZero << "]` must be equal to `values.shape["
|
||||
<< kIndexZero << "]`, but got `indices.shape[" << kIndexZero
|
||||
<< "]`: " << indices_shp[kIndexZero] << " and `values.shape[" << kIndexZero
|
||||
<< "]`: " << values_shp[kIndexZero];
|
||||
}
|
||||
constexpr int64_t kDimTwo = 2;
|
||||
if (indices_shp[kIndexOne] != kDimTwo) {
|
||||
MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexOne << "]` must be " << kDimTwo << ",but got "
|
||||
<< indices_shp[kIndexOne];
|
||||
}
|
||||
|
||||
for (const auto &elem_type : dense_shape->ElementsType()) {
|
||||
if (!elem_type->isa<Int>()) {
|
||||
MS_EXCEPTION(TypeError) << "For COOTensor, the element type of `shape` must be Int, but got "
|
||||
<< elem_type->ToString();
|
||||
}
|
||||
}
|
||||
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(dense_shape_value);
|
||||
auto shp = dense_shape_value->value();
|
||||
auto min_elem = *std::min_element(std::begin(shp), std::end(shp));
|
||||
if (min_elem <= 0) {
|
||||
MS_EXCEPTION(ValueError) << "For COOTensor, the element of `shape` must be positive integer. But got " << min_elem
|
||||
<< "int it";
|
||||
}
|
||||
ShapeVector dense_shape_vec;
|
||||
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
|
||||
[](const ValuePtr &e) -> int64_t {
|
||||
auto elem = GetValue<int64_t>(e);
|
||||
return elem;
|
||||
});
|
||||
if (LongToSize(indices_shp[kIndexOne]) != dense_shape_vec.size()) {
|
||||
MS_EXCEPTION(TypeError) << "For COOTensor, `indices.shape[" << indices_shp << "]` must be equal to the second "
|
||||
<< "dimension of `indices`: " << dense_shape_vec.size() << " but got "
|
||||
<< indices_shp[kIndexOne];
|
||||
}
|
||||
for (auto dense_shape_elem : dense_shape_vec) {
|
||||
if (dense_shape_elem <= 0) {
|
||||
MS_EXCEPTION(TypeError) << "For COOTensor, the element of `shape` must be positive, but got "
|
||||
<< dense_shape_value->ToString();
|
||||
}
|
||||
}
|
||||
AbstractBasePtrList element_list{indices, values, dense_shape};
|
||||
return std::make_shared<abstract::AbstractCOOTensor>(element_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCOOTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto sparse_tensor = CheckArg<AbstractCOOTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(sparse_tensor->values());
|
||||
return sparse_tensor->values();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCOOTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto sparse_tensor = CheckArg<AbstractCOOTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(sparse_tensor->indices());
|
||||
return sparse_tensor->indices();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCOOTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto sparse_tensor = CheckArg<AbstractCOOTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(sparse_tensor->shape());
|
||||
return sparse_tensor->shape();
|
||||
}
|
||||
|
||||
ShapeVector ConvertToShapeVector(const AbstractTuplePtr &shape) {
|
||||
auto shape_value = shape->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_value);
|
||||
ShapeVector shape_vec;
|
||||
(void)std::transform(std::begin(shape_value->value()), std::end(shape_value->value()), std::back_inserter(shape_vec),
|
||||
[](const ValuePtr &e) -> int64_t {
|
||||
auto elem = GetValue<int64_t>(e);
|
||||
return elem;
|
||||
});
|
||||
return shape_vec;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCSRElementWise(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a sparse tensor and a dense tensor.
|
||||
constexpr auto kCSRElementwiseInputsNum = 5;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, kCSRElementwiseInputsNum);
|
||||
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
|
||||
auto shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 3);
|
||||
auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, 4);
|
||||
MS_EXCEPTION_IF_NULL(indptr);
|
||||
MS_EXCEPTION_IF_NULL(indices);
|
||||
MS_EXCEPTION_IF_NULL(values);
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
MS_EXCEPTION_IF_NULL(dense);
|
||||
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
|
||||
|
||||
ShapeVector sparse_shape = ConvertToShapeVector(shape);
|
||||
auto dense_shape = dense->shape()->shape();
|
||||
CheckSparseShape(sparse_shape, dense_shape);
|
||||
auto ret = values->Broaden();
|
||||
// SetAttr
|
||||
auto nnz_vec = indices->shape()->shape();
|
||||
auto csr_avg_rows = nnz_vec[0] / dense_shape[0];
|
||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
constexpr auto kCSRMVInputsNum = 5;
|
||||
constexpr auto kCSRMVShapeSize = 2;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, kCSRMVInputsNum);
|
||||
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
|
||||
auto shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 3);
|
||||
auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, 4);
|
||||
MS_EXCEPTION_IF_NULL(indptr);
|
||||
MS_EXCEPTION_IF_NULL(indices);
|
||||
MS_EXCEPTION_IF_NULL(values);
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
MS_EXCEPTION_IF_NULL(dense);
|
||||
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
|
||||
|
||||
ShapeVector sparse_shape = ConvertToShapeVector(shape);
|
||||
ShapeVector dense_shape = dense->shape()->shape();
|
||||
if (sparse_shape.size() != kCSRMVShapeSize || dense_shape.size() != kCSRMVShapeSize) {
|
||||
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMVShapeSize << "-D inputs! "
|
||||
<< "But csr tensor has " << sparse_shape.size() << " dimensions, "
|
||||
<< "and dense tensor has " << dense_shape.size() << " dimension(s). ";
|
||||
}
|
||||
if (dense_shape[kIndexZero] != sparse_shape[kIndexOne] || dense_shape[kIndexOne] != 1) {
|
||||
MS_EXCEPTION(ValueError) << "The dense_vector's shape should be (" << sparse_shape[kIndexOne] << ", 1)"
|
||||
<< ", but its current shape is: "
|
||||
<< "(" << dense_shape[kIndexZero] << ", " << dense_shape[kIndexOne] << ").";
|
||||
}
|
||||
|
||||
ShapeVector out_shape = {sparse_shape[kIndexZero], dense_shape[kIndexOne]};
|
||||
auto ret = std::make_shared<AbstractTensor>(values->element()->BuildType(), out_shape);
|
||||
// SetAttr
|
||||
auto nnz_vec = indices->shape()->shape();
|
||||
auto csr_avg_rows = nnz_vec[kIndexZero] / dense_shape[kIndexZero];
|
||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a sparse tensor and an axis.
|
||||
constexpr auto kCSRReduceSumInputsNum = 5;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, kCSRReduceSumInputsNum);
|
||||
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
|
||||
auto shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 3);
|
||||
auto axis = CheckArg<AbstractScalar>(op_name, args_spec_list, 4);
|
||||
MS_EXCEPTION_IF_NULL(indptr);
|
||||
MS_EXCEPTION_IF_NULL(indices);
|
||||
MS_EXCEPTION_IF_NULL(values);
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
MS_EXCEPTION_IF_NULL(axis);
|
||||
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
|
||||
|
||||
ShapeVector sparse_shape = ConvertToShapeVector(shape);
|
||||
ShapeVector out_shape = sparse_shape;
|
||||
MS_EXCEPTION_IF_NULL(axis->BuildValue());
|
||||
if (axis->BuildValue()->isa<Int32Imm>() || axis->BuildValue()->isa<Int64Imm>()) {
|
||||
int64_t axis_value = GetValue<int64_t>(axis->BuildValue());
|
||||
int64_t dim = static_cast<int64_t>(sparse_shape.size());
|
||||
if (axis_value != 1 && axis_value != 1 - dim) {
|
||||
MS_EXCEPTION(ValueError) << "For CSRReduceSum, `axis` should be 1 or 1-dim. But got `axis`: " << axis_value
|
||||
<< "and `1- dim`: " << 1 - dim;
|
||||
}
|
||||
if (axis_value < 0) {
|
||||
axis_value += dim;
|
||||
}
|
||||
out_shape[LongToSize(axis_value)] = 1;
|
||||
primitive->set_attr(kCSRAxis, MakeValue(axis_value));
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For CSRReduceSum, `axis` should be int32 or int64, but got "
|
||||
<< axis->BuildType()->ToString();
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(values->element());
|
||||
auto ret = std::make_shared<AbstractTensor>(values->element()->BuildType(), out_shape);
|
||||
// SetAttr
|
||||
auto nnz_vec = indices->shape()->shape();
|
||||
auto csr_avg_rows = nnz_vec[0] / sparse_shape[0];
|
||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: the indptr and indices of a sparse csr tensor, a dense tensor, and the shape of the sparse tensor.
|
||||
constexpr size_t csr_row_num = 2;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, kSizeFour);
|
||||
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexZero);
|
||||
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexOne);
|
||||
auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexTwo);
|
||||
auto sparse_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, kIndexThree);
|
||||
MS_EXCEPTION_IF_NULL(indptr);
|
||||
MS_EXCEPTION_IF_NULL(indices);
|
||||
MS_EXCEPTION_IF_NULL(dense);
|
||||
MS_EXCEPTION_IF_NULL(sparse_shape);
|
||||
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
|
||||
|
||||
auto shape_value = sparse_shape->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_value);
|
||||
auto nnz_vec = indices->shape()->shape();
|
||||
int64_t csr_avg_rows = nnz_vec[0] / GetValue<int64_t>(shape_value->value()[0]);
|
||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||
|
||||
MS_EXCEPTION_IF_NULL(indices->shape());
|
||||
ShapeVector out_shape = indices->shape()->shape();
|
||||
MS_EXCEPTION_IF_NULL(dense->shape());
|
||||
ShapeVector dense_shape = dense->shape()->shape();
|
||||
for (size_t i = csr_row_num; i < dense_shape.size(); ++i) {
|
||||
out_shape.push_back(dense_shape[i]);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(dense->element());
|
||||
auto ret = std::make_shared<AbstractTensor>(dense->element()->BuildType(), out_shape);
|
||||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: the indptr of a sparse csr tensor, and the number of non-zero elements.
|
||||
constexpr auto kCSRArgsSize = 2;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);
|
||||
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
|
||||
auto nnz = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
MS_EXCEPTION_IF_NULL(indptr);
|
||||
MS_EXCEPTION_IF_NULL(nnz);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(nnz->BuildValue());
|
||||
ShapeVector out_shape;
|
||||
if (nnz->BuildValue()->isa<Int32Imm>() || nnz->BuildValue()->isa<Int64Imm>()) {
|
||||
int64_t nnz_value = GetValue<int64_t>(nnz->BuildValue());
|
||||
out_shape.push_back(nnz_value);
|
||||
} else {
|
||||
MS_EXCEPTION(ValueError) << "Currently, only support Integer nnz.";
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(indptr->shape());
|
||||
auto num_rows = indptr->shape()->shape()[0] - 1;
|
||||
int csr_avg_rows = GetValue<int64_t>(nnz->BuildValue()) / num_rows;
|
||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||
|
||||
MS_EXCEPTION_IF_NULL(indptr->element());
|
||||
auto ret = std::make_shared<AbstractTensor>(indptr->element()->BuildType(), out_shape);
|
||||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: the row indices of a sparse coo tensor, and the size of its first dimension.
|
||||
constexpr auto kCSRArgsSize = 2;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);
|
||||
auto row_indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto height = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
MS_EXCEPTION_IF_NULL(row_indices);
|
||||
MS_EXCEPTION_IF_NULL(height);
|
||||
CheckSparseIndicesDtypeInt32(row_indices->element()->BuildType(), "row_indices");
|
||||
MS_EXCEPTION_IF_NULL(height->BuildValue());
|
||||
ShapeVector out_shape;
|
||||
if (height->BuildValue()->isa<Int32Imm>() || height->BuildValue()->isa<Int64Imm>()) {
|
||||
int64_t height_value = GetValue<int64_t>(height->BuildValue());
|
||||
out_shape.push_back(height_value + 1);
|
||||
} else {
|
||||
MS_EXCEPTION(ValueError) << "Currently, only support Integer height.";
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(row_indices->element());
|
||||
auto ret = std::make_shared<AbstractTensor>(row_indices->element()->BuildType(), out_shape);
|
||||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeCSRTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: three tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, kSizeFour);
|
||||
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexZero);
|
||||
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexOne);
|
||||
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexTwo);
|
||||
auto shape = CheckArg<AbstractTuple>(op_name, args_spec_list, kIndexThree);
|
||||
|
||||
auto indptr_dtype = indptr->element()->BuildType();
|
||||
auto indices_dtype = indices->element()->BuildType();
|
||||
CheckSparseIndicesDtype(indptr_dtype, "indptr");
|
||||
CheckSparseIndicesDtype(indices_dtype, "indices");
|
||||
|
||||
auto indptr_shp = indptr->shape()->shape();
|
||||
CheckSparseShape(indptr_shp.size(), kSizeOne, "Indptr");
|
||||
|
||||
auto indices_shp = indices->shape()->shape();
|
||||
CheckSparseShape(indices_shp.size(), kSizeOne, "Indices");
|
||||
|
||||
auto values_shp = values->shape()->shape();
|
||||
if (indices_shp[kIndexZero] != values_shp[kIndexZero]) {
|
||||
MS_EXCEPTION(ValueError) << "Indices and values must have same size, but got: values length: "
|
||||
<< values_shp[kIndexZero] << ", indices length " << indices_shp[kIndexZero];
|
||||
}
|
||||
|
||||
auto shape_value = shape->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_value);
|
||||
auto shp = shape_value->value();
|
||||
ShapeVector shape_vec;
|
||||
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(shape_vec), [](const ValuePtr &e) -> int64_t {
|
||||
auto elem = GetValue<int64_t>(e);
|
||||
return elem;
|
||||
});
|
||||
if (values_shp.size() + 1 != shape_vec.size()) {
|
||||
MS_EXCEPTION(ValueError) << "Values' dimension should equal to csr_tensor's dimension - 1, but got"
|
||||
<< "Values' dimension: " << values_shp.size()
|
||||
<< ", csr_tensor's dimension: " << shape_vec.size() << ".";
|
||||
}
|
||||
if (shape_vec[kIndexZero] + 1 != indptr_shp[kIndexZero]) {
|
||||
MS_EXCEPTION(ValueError) << "Indptr must have length (1 + shape[0]), but got: " << indptr_shp[kIndexZero];
|
||||
}
|
||||
size_t shape_size = 1;
|
||||
auto shape_types = shape->ElementsType();
|
||||
for (size_t i = 0; i < shape_vec.size(); ++i) {
|
||||
if (shape_vec[i] <= 0) {
|
||||
MS_EXCEPTION(ValueError) << "The element of shape must be positive, but got " << shape_value->ToString();
|
||||
}
|
||||
if ((i > 1) && (shape_vec[i] != values_shp[i - 1])) {
|
||||
MS_EXCEPTION(ValueError) << "csr_tensor's shape should match with values' shape.";
|
||||
}
|
||||
if (!shape_types[i]->isa<Int>()) {
|
||||
MS_EXCEPTION(TypeError) << "The element type of shape must be Int, but got " << shape_types[i]->ToString();
|
||||
}
|
||||
shape_size *= LongToSize(shape_vec[i]);
|
||||
}
|
||||
if (static_cast<int64_t>(shape_size) < values_shp[kIndexZero]) {
|
||||
MS_EXCEPTION(ValueError) << "Shape total size: " << shape_size << " is too small to hold " << values_shp[kIndexZero]
|
||||
<< " non-zero values.";
|
||||
}
|
||||
AbstractBasePtrList element_list{indptr, indices, values, shape};
|
||||
return std::make_shared<abstract::AbstractCSRTensor>(element_list);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
return CheckArg<T>(op_name, args_spec_list, 0);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCSRTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
auto csr_tensor = InferSparseAttr<AbstractCSRTensor>(primitive, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor->values());
|
||||
return csr_tensor->values();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCSRTensorGetIndptr(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
auto csr_tensor = InferSparseAttr<AbstractCSRTensor>(primitive, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor->indptr());
|
||||
return csr_tensor->indptr();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCSRTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
auto csr_tensor = InferSparseAttr<AbstractCSRTensor>(primitive, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor->indices());
|
||||
return csr_tensor->indices();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplCSRTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
auto csr_tensor = InferSparseAttr<AbstractCSRTensor>(primitive, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor->shape());
|
||||
return csr_tensor->shape();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplAllSwap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
const std::string op_name = primitive->name();
|
||||
|
@ -1008,48 +525,5 @@ AbstractBasePtr InferImplAdamApplyOneWithDecay(const AnalysisEnginePtr &, const
|
|||
AbstractBasePtrList rets = {add1, add0, sub0};
|
||||
return std::make_shared<AbstractTuple>(rets);
|
||||
}
|
||||
AbstractBasePtr InferImplCSRMM(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a sparse tensor and a dense tensor.
|
||||
constexpr auto kCSRMMInputsNum = 5;
|
||||
constexpr auto kCSRMMShapeSize = 2;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, kCSRMMInputsNum);
|
||||
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
|
||||
auto shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 3);
|
||||
auto dense = CheckArg<AbstractTensor>(op_name, args_spec_list, 4);
|
||||
MS_EXCEPTION_IF_NULL(indptr);
|
||||
MS_EXCEPTION_IF_NULL(indices);
|
||||
MS_EXCEPTION_IF_NULL(values);
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
MS_EXCEPTION_IF_NULL(dense);
|
||||
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
|
||||
|
||||
ShapeVector sparse_shape = ConvertToShapeVector(shape);
|
||||
auto dense_shape = dense->shape()->shape();
|
||||
if (sparse_shape.size() != kCSRMMShapeSize || dense_shape.size() != kCSRMMShapeSize) {
|
||||
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMMShapeSize << "-D inputs! "
|
||||
<< "But csr tensor has " << sparse_shape.size() << " dimensions, "
|
||||
<< "and dense tensor has " << dense_shape.size() << " dimensions, ";
|
||||
}
|
||||
if (dense_shape[kIndexZero] != sparse_shape[kIndexOne]) {
|
||||
MS_EXCEPTION(ValueError) << "The dense's shape[0] should be equal to csr tensor's shape[1]"
|
||||
<< ", but dense's shape[0] is: " << dense_shape[kIndexZero]
|
||||
<< " and csr tensor's shape[1] is " << sparse_shape[kIndexOne];
|
||||
}
|
||||
|
||||
ShapeVector out_shape = {sparse_shape[kIndexZero], dense_shape[kIndexOne]};
|
||||
auto ret = std::make_shared<AbstractTensor>(values->element()->BuildType(), out_shape);
|
||||
// SetAttr
|
||||
auto nnz_vec = indices->shape()->shape();
|
||||
auto csr_avg_rows = nnz_vec[kIndexZero] / dense_shape[kIndexZero];
|
||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||
return ret;
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -318,31 +318,12 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimDebug, R{InferImplDebug, nullptr, true}},
|
||||
// Dynamic shape testing
|
||||
{prim::kPrimGpuConvertToDynamicShape, R{InferImplGpuConvertToDynamicShape, nullptr, true}},
|
||||
// COOTensor
|
||||
{prim::kPrimMakeCOOTensor, R{InferImplMakeCOOTensor, nullptr, true}},
|
||||
{prim::kPrimCOOTensorGetValues, R{InferImplCOOTensorGetValues, nullptr, true}},
|
||||
{prim::kPrimCOOTensorGetIndices, R{InferImplCOOTensorGetIndices, nullptr, true}},
|
||||
{prim::kPrimCOOTensorGetDenseShape, R{InferImplCOOTensorGetDenseShape, nullptr, true}},
|
||||
// RowTensor
|
||||
{prim::kPrimMakeRowTensor, R{InferImplMakeRowTensor, nullptr, true}},
|
||||
{prim::kPrimRowTensorGetValues, R{InferImplRowTensorGetValues, nullptr, true}},
|
||||
{prim::kPrimRowTensorGetIndices, R{InferImplRowTensorGetIndices, nullptr, true}},
|
||||
{prim::kPrimRowTensorGetDenseShape, R{InferImplRowTensorGetDenseShape, nullptr, true}},
|
||||
{prim::kPrimRowTensorAdd, R{InferImplRowTensorAdd, nullptr, false}},
|
||||
// CSRTensor
|
||||
{prim::kPrimMakeCSRTensor, R{InferImplMakeCSRTensor, nullptr, true}},
|
||||
{prim::kPrimCSRTensorGetValues, R{InferImplCSRTensorGetValues, nullptr, true}},
|
||||
{prim::kPrimCSRTensorGetIndptr, R{InferImplCSRTensorGetIndptr, nullptr, true}},
|
||||
{prim::kPrimCSRTensorGetIndices, R{InferImplCSRTensorGetIndices, nullptr, true}},
|
||||
{prim::kPrimCSRTensorGetDenseShape, R{InferImplCSRTensorGetDenseShape, nullptr, true}},
|
||||
{prim::kPrimCSRMul, R{InferImplCSRElementWise, nullptr, true}},
|
||||
{prim::kPrimCSRDiv, R{InferImplCSRElementWise, nullptr, true}},
|
||||
{prim::kPrimCSRMV, R{InferImplCSRMV, nullptr, true}},
|
||||
{prim::kPrimCSRMM, R{InferImplCSRMM, nullptr, true}},
|
||||
{prim::kPrimCSRReduceSum, R{InferImplCSRReduceSum, nullptr, true}},
|
||||
{prim::kPrimCSRGather, R{InferImplCSRGather, nullptr, true}},
|
||||
{prim::kPrimCSR2COO, R{InferImplCSR2COO, nullptr, true}},
|
||||
{prim::kPrimCOO2CSR, R{InferImplCOO2CSR, nullptr, true}},
|
||||
// Comm Ops
|
||||
{prim::kPrimAllSwap, R{InferImplAllSwap, nullptr, true}},
|
||||
{prim::kPrimMemCpyAsync, R{InferImplMemCpyAsync, nullptr, true}},
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/coo_tensor_get_indices.h"
|
||||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
AbstractBasePtr COOTensorGetIndicesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
|
||||
auto coo_tensor = InferSparseAttr<abstract::AbstractCOOTensor>(primitive, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(coo_tensor->indices());
|
||||
return coo_tensor->indices();
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(COOTensorGetIndices, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(COOTensorGetIndices, prim::kPrimCOOTensorGetIndices, COOTensorGetIndicesInfer, nullptr,
|
||||
true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_COOTENSOR_GET_INDICES_H_
|
||||
#define MINDSPORE_CORE_OPS_COOTENSOR_GET_INDICES_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCOOTensorGetIndices = "COOTensorGetIndices";
|
||||
class MIND_API COOTensorGetIndices : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(COOTensorGetIndices);
|
||||
/// \brief Constructor.
|
||||
COOTensorGetIndices() : BaseOperator(kNameCOOTensorGetIndices) {}
|
||||
};
|
||||
abstract::AbstractBasePtr COOTensorGetIndicesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_COOTENSOR_GET_INDICES_H_
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/coo_tensor_get_shape.h"
|
||||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
abstract::AbstractBasePtr COOTensorGetShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
|
||||
auto coo_tensor = InferSparseAttr<abstract::AbstractCOOTensor>(primitive, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(coo_tensor->shape());
|
||||
return coo_tensor->shape();
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(COOTensorGetShape, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(COOTensorGetShape, prim::kPrimCOOTensorGetDenseShape, COOTensorGetShapeInfer, nullptr,
|
||||
true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_COOTENSOR_GET_SHAPE_H_
|
||||
#define MINDSPORE_CORE_OPS_COOTENSOR_GET_SHAPE_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCOOTensorGetShape = "COOTensorGetShape";
|
||||
class MIND_API COOTensorGetShape : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(COOTensorGetShape);
|
||||
/// \brief Constructor.
|
||||
COOTensorGetShape() : BaseOperator(kNameCOOTensorGetShape) {}
|
||||
};
|
||||
abstract::AbstractBasePtr COOTensorGetShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_COOTENSOR_GET_SHAPE_H_
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/coo_tensor_get_values.h"
|
||||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
abstract::AbstractBasePtr COOTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
|
||||
auto coo_tensor = InferSparseAttr<abstract::AbstractCOOTensor>(primitive, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(coo_tensor->values());
|
||||
return coo_tensor->values();
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(COOTensorGetValues, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(COOTensorGetValues, prim::kPrimCOOTensorGetValues, COOTensorGetValuesInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_COOTENSOR_GET_VALUES_H_
|
||||
#define MINDSPORE_CORE_OPS_COOTENSOR_GET_VALUES_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCOOTensorGetValues = "COOTensorGetValues";
|
||||
class MIND_API COOTensorGetValues : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(COOTensorGetValues);
|
||||
/// \brief Constructor.
|
||||
COOTensorGetValues() : BaseOperator(kNameCOOTensorGetValues) {}
|
||||
};
|
||||
abstract::AbstractBasePtr COOTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_COOTENSOR_GET_VALUES_H_
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/coo_to_csr.h"
|
||||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
using abstract::AbstractScalar;
|
||||
using abstract::AbstractTensor;
|
||||
using abstract::AbstractTuple;
|
||||
AbstractBasePtr COO2CSRInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
// Inputs: the row indices of a sparse coo tensor, and the size of its first dimension.
|
||||
constexpr auto kCSRArgsSize = 2;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, input_args, kCSRArgsSize);
|
||||
auto row_indices = abstract::CheckArg<AbstractTensor>(op_name, input_args, 0);
|
||||
auto height = abstract::CheckArg<AbstractScalar>(op_name, input_args, 1);
|
||||
MS_EXCEPTION_IF_NULL(row_indices);
|
||||
MS_EXCEPTION_IF_NULL(height);
|
||||
CheckSparseIndicesDtypeInt32(row_indices->element()->BuildType(), "row_indices");
|
||||
MS_EXCEPTION_IF_NULL(height->BuildValue());
|
||||
ShapeVector out_shape;
|
||||
if (height->BuildValue()->isa<Int32Imm>() || height->BuildValue()->isa<Int64Imm>()) {
|
||||
int64_t height_value = GetValue<int64_t>(height->BuildValue());
|
||||
out_shape.push_back(height_value + 1);
|
||||
} else {
|
||||
MS_EXCEPTION(ValueError) << "Currently, only support Integer height.";
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(row_indices->element());
|
||||
auto ret = std::make_shared<AbstractTensor>(row_indices->element()->BuildType(), out_shape);
|
||||
return ret;
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(COO2CSR, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(COO2CSR, prim::kPrimCOO2CSR, COO2CSRInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_COO_TO_CSR
|
||||
#define MINDSPORE_CORE_OPS_COO_TO_CSR
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCOO2CSR = "COO2CSR";
|
||||
/// \brief Converts the row indices of a COOTensor to the indptr of a CSRTensor.
|
||||
class MIND_API COO2CSR : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(COO2CSR);
|
||||
/// \brief Constructor.
|
||||
COO2CSR() : BaseOperator(kNameCOO2CSR) { InitIOName({"row_indices", "height"}, {"output"}); }
|
||||
};
|
||||
abstract::AbstractBasePtr COO2CSRInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_COO_TO_CSR
|
|
@ -0,0 +1,68 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/csr_elementwise.h"
|
||||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
using abstract::AbstractTensor;
|
||||
using abstract::AbstractTuple;
|
||||
AbstractBasePtr CSRElementWiseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
// Inputs: a sparse tensor and a dense tensor.
|
||||
constexpr auto kCSRElementwiseInputsNum = 5;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, input_args, kCSRElementwiseInputsNum);
|
||||
auto indptr = abstract::CheckArg<AbstractTensor>(op_name, input_args, 0);
|
||||
auto indices = abstract::CheckArg<AbstractTensor>(op_name, input_args, 1);
|
||||
auto values = abstract::CheckArg<AbstractTensor>(op_name, input_args, 2);
|
||||
auto shape = abstract::CheckArg<AbstractTuple>(op_name, input_args, 3);
|
||||
auto dense = abstract::CheckArg<AbstractTensor>(op_name, input_args, 4);
|
||||
MS_EXCEPTION_IF_NULL(indptr);
|
||||
MS_EXCEPTION_IF_NULL(indices);
|
||||
MS_EXCEPTION_IF_NULL(values);
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
MS_EXCEPTION_IF_NULL(dense);
|
||||
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
|
||||
|
||||
ShapeVector sparse_shape = ConvertToShapeVector(shape);
|
||||
auto dense_shape = dense->shape()->shape();
|
||||
CheckSparseShape(sparse_shape, dense_shape);
|
||||
auto ret = values->Broaden();
|
||||
// SetAttr
|
||||
auto nnz_vec = indices->shape()->shape();
|
||||
auto csr_avg_rows = nnz_vec[0] / dense_shape[0];
|
||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||
return ret;
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(CSRMul, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(CSRMul, prim::kPrimCSRMul, CSRElementWiseInfer, nullptr, true);
|
||||
MIND_API_OPERATOR_IMPL(CSRDiv, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(CSRDiv, prim::kPrimCSRDiv, CSRElementWiseInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CSR_ELEMENTWISE
|
||||
#define MINDSPORE_CORE_OPS_CSR_ELEMENTWISE
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCSRMul = "CSRMul";
|
||||
constexpr auto kNameCSRDiv = "CSRDiv";
|
||||
/// \brief CSRTensor elementwise operation.
|
||||
class MIND_API CSRMul : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(CSRMul);
|
||||
/// \brief Constructor.
|
||||
CSRMul() : BaseOperator(kNameCSRMul) {
|
||||
InitIOName({"indptr", "indices", "values", "dense_shape", "dense_tensor"}, {"output"});
|
||||
}
|
||||
};
|
||||
|
||||
class MIND_API CSRDiv : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(CSRDiv);
|
||||
/// \brief Constructor.
|
||||
CSRDiv() : BaseOperator(kNameCSRDiv) {
|
||||
InitIOName({"indptr", "indices", "values", "dense_shape", "dense_tensor"}, {"output"});
|
||||
}
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr CSRElementWiseInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_CSR_ELEMENTWISE
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/csr_gather.h"
|
||||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
using abstract::AbstractTensor;
|
||||
using abstract::AbstractTuple;
|
||||
AbstractBasePtr CSRGatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
// Inputs: the indptr and indices of a sparse csr tensor, a dense tensor, and the shape of the sparse tensor.
|
||||
constexpr size_t csr_row_num = 2;
|
||||
const std::string op_name = primitive->name();
|
||||
abstract::CheckArgsSize(op_name, input_args, kSizeFour);
|
||||
auto indptr = abstract::CheckArg<AbstractTensor>(op_name, input_args, kIndexZero);
|
||||
auto indices = abstract::CheckArg<AbstractTensor>(op_name, input_args, kIndexOne);
|
||||
auto dense = abstract::CheckArg<AbstractTensor>(op_name, input_args, kIndexTwo);
|
||||
auto sparse_shape = abstract::CheckArg<AbstractTuple>(op_name, input_args, kIndexThree);
|
||||
MS_EXCEPTION_IF_NULL(indptr);
|
||||
MS_EXCEPTION_IF_NULL(indices);
|
||||
MS_EXCEPTION_IF_NULL(dense);
|
||||
MS_EXCEPTION_IF_NULL(sparse_shape);
|
||||
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
|
||||
|
||||
auto shape_value = sparse_shape->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_value);
|
||||
auto nnz_vec = indices->shape()->shape();
|
||||
int64_t csr_avg_rows = nnz_vec[0] / GetValue<int64_t>(shape_value->value()[0]);
|
||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||
|
||||
MS_EXCEPTION_IF_NULL(indices->shape());
|
||||
ShapeVector out_shape = indices->shape()->shape();
|
||||
MS_EXCEPTION_IF_NULL(dense->shape());
|
||||
ShapeVector dense_shape = dense->shape()->shape();
|
||||
for (size_t i = csr_row_num; i < dense_shape.size(); ++i) {
|
||||
out_shape.push_back(dense_shape[i]);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(dense->element());
|
||||
auto ret = std::make_shared<AbstractTensor>(dense->element()->BuildType(), out_shape);
|
||||
return ret;
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(CSRGather, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(CSRGather, prim::kPrimCSRGather, CSRGatherInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CSR_GATHER
|
||||
#define MINDSPORE_CORE_OPS_CSR_GATHER
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCSRGather = "CSRGather";
|
||||
/// \brief Returns the values of a CSRTensor indexed from a dense tensor using indptr and indices.
|
||||
class MIND_API CSRGather : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(CSRGather);
|
||||
/// \brief Constructor.
|
||||
CSRGather() : BaseOperator(kNameCSRGather) { InitIOName({"indptr", "indices", "dense", "dense_shape"}, {"output"}); }
|
||||
};
|
||||
abstract::AbstractBasePtr CSRGatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_CSR_GATHER
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/csr_mm.h"
|
||||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
using abstract::AbstractTensor;
|
||||
using abstract::AbstractTuple;
|
||||
AbstractBasePtr CSRMMInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
// Inputs: a sparse tensor and a dense tensor.
|
||||
constexpr auto kCSRMMInputsNum = 5;
|
||||
constexpr auto kCSRMMShapeSize = 2;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, input_args, kCSRMMInputsNum);
|
||||
auto indptr = abstract::CheckArg<AbstractTensor>(op_name, input_args, 0);
|
||||
auto indices = abstract::CheckArg<AbstractTensor>(op_name, input_args, 1);
|
||||
auto values = abstract::CheckArg<AbstractTensor>(op_name, input_args, 2);
|
||||
auto shape = abstract::CheckArg<AbstractTuple>(op_name, input_args, 3);
|
||||
auto dense = abstract::CheckArg<AbstractTensor>(op_name, input_args, 4);
|
||||
MS_EXCEPTION_IF_NULL(indptr);
|
||||
MS_EXCEPTION_IF_NULL(indices);
|
||||
MS_EXCEPTION_IF_NULL(values);
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
MS_EXCEPTION_IF_NULL(dense);
|
||||
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
|
||||
|
||||
ShapeVector sparse_shape = ConvertToShapeVector(shape);
|
||||
auto dense_shape = dense->shape()->shape();
|
||||
if (sparse_shape.size() != kCSRMMShapeSize || dense_shape.size() != kCSRMMShapeSize) {
|
||||
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMMShapeSize << "-D inputs! "
|
||||
<< "But csr tensor has " << sparse_shape.size() << " dimensions, "
|
||||
<< "and dense tensor has " << dense_shape.size() << " dimension(s). ";
|
||||
}
|
||||
if (dense_shape[kIndexZero] != sparse_shape[kIndexOne]) {
|
||||
MS_EXCEPTION(ValueError) << "The dense's shape[0] should be equal to csr tensor's shape[1]"
|
||||
<< ", but dense's shape[0] is: " << dense_shape[kIndexZero]
|
||||
<< " and csr tensor's shape[1] is " << sparse_shape[kIndexOne];
|
||||
}
|
||||
|
||||
ShapeVector out_shape = {sparse_shape[kIndexZero], dense_shape[kIndexOne]};
|
||||
auto ret = std::make_shared<AbstractTensor>(values->element()->BuildType(), out_shape);
|
||||
// SetAttr
|
||||
auto nnz_vec = indices->shape()->shape();
|
||||
auto csr_avg_rows = nnz_vec[kIndexZero] / dense_shape[kIndexZero];
|
||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||
return ret;
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(CSRMM, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(CSRMM, prim::kPrimCSRMM, CSRMMInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CSR_MM
|
||||
#define MINDSPORE_CORE_OPS_CSR_MM
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCSRMM = "CSRMM";
|
||||
/// \brief Sparse matrix-matrix multiplication.
|
||||
class MIND_API CSRMM : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(CSRMM);
|
||||
/// \brief Constructor.
|
||||
CSRMM() : BaseOperator(kNameCSRMM) {
|
||||
InitIOName({"indptr", "indices", "values", "dense_shape", "dense_tensor"}, {"output"});
|
||||
}
|
||||
};
|
||||
abstract::AbstractBasePtr CSRMMInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_CSR_MM
|
|
@ -0,0 +1,77 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/csr_mv.h"
|
||||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
using abstract::AbstractTensor;
|
||||
using abstract::AbstractTuple;
|
||||
AbstractBasePtr CSRMVInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
constexpr auto kCSRMVInputsNum = 5;
|
||||
constexpr auto kCSRMVShapeSize = 2;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, input_args, kCSRMVInputsNum);
|
||||
auto indptr = abstract::CheckArg<AbstractTensor>(op_name, input_args, 0);
|
||||
auto indices = abstract::CheckArg<AbstractTensor>(op_name, input_args, 1);
|
||||
auto values = abstract::CheckArg<AbstractTensor>(op_name, input_args, 2);
|
||||
auto shape = abstract::CheckArg<AbstractTuple>(op_name, input_args, 3);
|
||||
auto dense = abstract::CheckArg<AbstractTensor>(op_name, input_args, 4);
|
||||
MS_EXCEPTION_IF_NULL(indptr);
|
||||
MS_EXCEPTION_IF_NULL(indices);
|
||||
MS_EXCEPTION_IF_NULL(values);
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
MS_EXCEPTION_IF_NULL(dense);
|
||||
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
|
||||
|
||||
ShapeVector sparse_shape = ConvertToShapeVector(shape);
|
||||
ShapeVector dense_shape = dense->shape()->shape();
|
||||
if (sparse_shape.size() != kCSRMVShapeSize || dense_shape.size() != kCSRMVShapeSize) {
|
||||
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRMVShapeSize << "-D inputs! "
|
||||
<< "But csr tensor has " << sparse_shape.size() << " dimensions, "
|
||||
<< "and dense tensor has " << dense_shape.size() << " dimension(s). ";
|
||||
}
|
||||
if (dense_shape[kIndexZero] != sparse_shape[kIndexOne] || dense_shape[kIndexOne] != 1) {
|
||||
MS_EXCEPTION(ValueError) << "The dense_vector's shape should be (" << sparse_shape[kIndexOne] << ", 1)"
|
||||
<< ", but its current shape is: "
|
||||
<< "(" << dense_shape[kIndexZero] << ", " << dense_shape[kIndexOne] << ").";
|
||||
}
|
||||
|
||||
ShapeVector out_shape = {sparse_shape[kIndexZero], dense_shape[kIndexOne]};
|
||||
auto ret = std::make_shared<AbstractTensor>(values->element()->BuildType(), out_shape);
|
||||
// SetAttr
|
||||
auto nnz_vec = indices->shape()->shape();
|
||||
auto csr_avg_rows = nnz_vec[kIndexZero] / dense_shape[kIndexZero];
|
||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||
return ret;
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(CSRMV, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(CSRMV, prim::kPrimCSRMV, CSRMVInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CSR_MV
|
||||
#define MINDSPORE_CORE_OPS_CSR_MV
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCSRMV = "CSRMV";
|
||||
/// \brief Sparse matrix-vector multiplication.
|
||||
class MIND_API CSRMV : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(CSRMV);
|
||||
/// \brief Constructor.
|
||||
CSRMV() : BaseOperator(kNameCSRMV) {
|
||||
InitIOName({"indptr", "indices", "values", "dense_shape", "dense_tensor"}, {"output"});
|
||||
}
|
||||
};
|
||||
abstract::AbstractBasePtr CSRMVInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_CSR_MV
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/csr_reducesum.h"
|
||||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
using abstract::AbstractScalar;
|
||||
using abstract::AbstractTensor;
|
||||
using abstract::AbstractTuple;
|
||||
AbstractBasePtr CSRReduceSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
// Inputs: a sparse tensor and an axis.
|
||||
constexpr auto kCSRReduceSumInputsNum = 5;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, input_args, kCSRReduceSumInputsNum);
|
||||
auto indptr = abstract::CheckArg<AbstractTensor>(op_name, input_args, 0);
|
||||
auto indices = abstract::CheckArg<AbstractTensor>(op_name, input_args, 1);
|
||||
auto values = abstract::CheckArg<AbstractTensor>(op_name, input_args, 2);
|
||||
auto shape = abstract::CheckArg<AbstractTuple>(op_name, input_args, 3);
|
||||
auto axis = abstract::CheckArg<AbstractScalar>(op_name, input_args, 4);
|
||||
MS_EXCEPTION_IF_NULL(indptr);
|
||||
MS_EXCEPTION_IF_NULL(indices);
|
||||
MS_EXCEPTION_IF_NULL(values);
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
MS_EXCEPTION_IF_NULL(axis);
|
||||
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
|
||||
|
||||
ShapeVector sparse_shape = ConvertToShapeVector(shape);
|
||||
ShapeVector out_shape = sparse_shape;
|
||||
MS_EXCEPTION_IF_NULL(axis->BuildValue());
|
||||
if (axis->BuildValue()->isa<Int32Imm>() || axis->BuildValue()->isa<Int64Imm>()) {
|
||||
int64_t axis_value = GetValue<int64_t>(axis->BuildValue());
|
||||
int64_t dim = static_cast<int64_t>(sparse_shape.size());
|
||||
if (axis_value != 1 && axis_value != 1 - dim) {
|
||||
MS_EXCEPTION(ValueError) << "For CSRReduceSum, `axis` should be 1 or 1-dim. But got `axis`: " << axis_value
|
||||
<< "and `1- dim`: " << 1 - dim << ".";
|
||||
}
|
||||
if (axis_value < 0) {
|
||||
axis_value += dim;
|
||||
}
|
||||
out_shape[LongToSize(axis_value)] = 1;
|
||||
primitive->set_attr(kCSRAxis, MakeValue(axis_value));
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For CSRReduceSum, `axis` should be int32 or int64, but got "
|
||||
<< axis->BuildType()->ToString() << ".";
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(values->element());
|
||||
auto ret = std::make_shared<AbstractTensor>(values->element()->BuildType(), out_shape);
|
||||
// SetAttr
|
||||
auto nnz_vec = indices->shape()->shape();
|
||||
auto csr_avg_rows = nnz_vec[0] / sparse_shape[0];
|
||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||
return ret;
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(CSRReduceSum, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(CSRReduceSum, prim::kPrimCSRReduceSum, CSRReduceSumInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CSR_REDUCESUM
|
||||
#define MINDSPORE_CORE_OPS_CSR_REDUCESUM
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCSRReduceSum = "CSRReduceSum";
|
||||
/// \brief CSRTensor reducesum.
|
||||
class MIND_API CSRReduceSum : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(CSRReduceSum);
|
||||
/// \brief Constructor.
|
||||
CSRReduceSum() : BaseOperator(kNameCSRReduceSum) {
|
||||
InitIOName({"indptr", "indices", "values", "dense_shape", "axis"}, {"output"});
|
||||
}
|
||||
};
|
||||
abstract::AbstractBasePtr CSRReduceSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_CSR_REDUCESUM
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/csr_tensor_get_indices.h"
|
||||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
abstract::AbstractBasePtr CSRTensorGetIndicesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
|
||||
auto csr_tensor = InferSparseAttr<abstract::AbstractCSRTensor>(primitive, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor->indices());
|
||||
return csr_tensor->indices();
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(CSRTensorGetIndices, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(CSRTensorGetIndices, prim::kPrimCSRTensorGetIndices, CSRTensorGetIndicesInfer, nullptr,
|
||||
true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDICES_H_
|
||||
#define MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDICES_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCSRTensorGetIndices = "CSRTensorGetIndices";
|
||||
/// \brief CSRTensorGetIndices op is used to get indptr in CSRTensor.
|
||||
class MIND_API CSRTensorGetIndices : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(CSRTensorGetIndices);
|
||||
/// \brief Constructor.
|
||||
CSRTensorGetIndices() : BaseOperator(kNameCSRTensorGetIndices) {}
|
||||
};
|
||||
abstract::AbstractBasePtr CSRTensorGetIndicesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDICES_H_
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/csr_tensor_get_indptr.h"
|
||||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
abstract::AbstractBasePtr CSRTensorGetIndptrInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
|
||||
auto csr_tensor = InferSparseAttr<abstract::AbstractCSRTensor>(primitive, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor->indptr());
|
||||
return csr_tensor->indptr();
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(CSRTensorGetIndptr, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(CSRTensorGetIndptr, prim::kPrimCSRTensorGetIndptr, CSRTensorGetIndptrInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDPTR_H_
|
||||
#define MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDPTR_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCSRTensorGetIndptr = "CSRTensorGetIndptr";
|
||||
/// \brief CSRTensorGetIndptr op is used to get indptr in CSRTensor.
|
||||
class MIND_API CSRTensorGetIndptr : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(CSRTensorGetIndptr);
|
||||
/// \brief Constructor.
|
||||
CSRTensorGetIndptr() : BaseOperator(kNameCSRTensorGetIndptr) {}
|
||||
};
|
||||
abstract::AbstractBasePtr CSRTensorGetIndptrInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_CSRTENSOR_GET_INDPTR_H_
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/csr_tensor_get_shape.h"
|
||||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
abstract::AbstractBasePtr CSRTensorGetShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
|
||||
auto csr_tensor = InferSparseAttr<abstract::AbstractCSRTensor>(primitive, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor->shape());
|
||||
return csr_tensor->shape();
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(CSRTensorGetShape, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(CSRTensorGetShape, prim::kPrimCSRTensorGetDenseShape, CSRTensorGetShapeInfer, nullptr,
|
||||
true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CSRTENSOR_GET_SHAPE_H_
|
||||
#define MINDSPORE_CORE_OPS_CSRTENSOR_GET_SHAPE_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCSRTensorGetShape = "CSRTensorGetShape";
|
||||
class MIND_API CSRTensorGetShape : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(CSRTensorGetShape);
|
||||
/// \brief Constructor.
|
||||
CSRTensorGetShape() : BaseOperator(kNameCSRTensorGetShape) {}
|
||||
};
|
||||
abstract::AbstractBasePtr CSRTensorGetShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_CSRTENSOR_GET_SHAPE_H_
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/csr_tensor_get_values.h"
|
||||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
abstract::AbstractBasePtr CSRTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
|
||||
auto csr_tensor = InferSparseAttr<abstract::AbstractCSRTensor>(primitive, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(csr_tensor->values());
|
||||
return csr_tensor->values();
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(CSRTensorGetValues, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(CSRTensorGetValues, prim::kPrimCSRTensorGetValues, CSRTensorGetValuesInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CSRTENSOR_GET_VALUES_H_
|
||||
#define MINDSPORE_CORE_OPS_CSRTENSOR_GET_VALUES_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCSRTensorGetValues = "CSRTensorGetValues";
|
||||
/// \brief CSRTensorGetIndices op is used to get indptr in CSRTensor.
|
||||
class MIND_API CSRTensorGetValues : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(CSRTensorGetValues);
|
||||
/// \brief Constructor.
|
||||
CSRTensorGetValues() : BaseOperator(kNameCSRTensorGetValues) {}
|
||||
};
|
||||
abstract::AbstractBasePtr CSRTensorGetValuesInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_CSRTENSOR_GET_VALUES_H_
|
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/csr_to_coo.h"
|
||||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
using abstract::AbstractScalar;
|
||||
using abstract::AbstractTensor;
|
||||
using abstract::AbstractTuple;
|
||||
AbstractBasePtr CSR2COOInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
// Inputs: the indptr of a sparse csr tensor, and the number of non-zero elements.
|
||||
constexpr auto kCSRArgsSize = 2;
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, input_args, kCSRArgsSize);
|
||||
auto indptr = abstract::CheckArg<AbstractTensor>(op_name, input_args, 0);
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
|
||||
auto nnz = abstract::CheckArg<AbstractScalar>(op_name, input_args, 1);
|
||||
MS_EXCEPTION_IF_NULL(indptr);
|
||||
MS_EXCEPTION_IF_NULL(nnz);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(nnz->BuildValue());
|
||||
ShapeVector out_shape;
|
||||
if (nnz->BuildValue()->isa<Int32Imm>() || nnz->BuildValue()->isa<Int64Imm>()) {
|
||||
int64_t nnz_value = GetValue<int64_t>(nnz->BuildValue());
|
||||
out_shape.push_back(nnz_value);
|
||||
} else {
|
||||
MS_EXCEPTION(ValueError) << "Currently, only support Integer nnz.";
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(indptr->shape());
|
||||
auto num_rows = indptr->shape()->shape()[0] - 1;
|
||||
int csr_avg_rows = GetValue<int64_t>(nnz->BuildValue()) / num_rows;
|
||||
primitive->set_attr(kCSRAvgRows, MakeValue(csr_avg_rows));
|
||||
primitive->set_attr(kIsCSR, MakeValue(true));
|
||||
|
||||
MS_EXCEPTION_IF_NULL(indptr->element());
|
||||
auto ret = std::make_shared<AbstractTensor>(indptr->element()->BuildType(), out_shape);
|
||||
return ret;
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(CSR2COO, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(CSR2COO, prim::kPrimCSR2COO, CSR2COOInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_CSR_TO_COO
|
||||
#define MINDSPORE_CORE_OPS_CSR_TO_COO
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameCSR2COO = "CSR2COO";
|
||||
/// \brief Converts the indptr of a CSRTensor to the row indices of a COOTensor.
|
||||
class MIND_API CSR2COO : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(CSR2COO);
|
||||
/// \brief Constructor.
|
||||
CSR2COO() : BaseOperator(kNameCSR2COO) { InitIOName({"indptr", "nnz"}, {"output"}); }
|
||||
};
|
||||
abstract::AbstractBasePtr CSR2COOInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_CSR_TO_COO
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/make_cootensor.h"
|
||||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
using abstract::AbstractTensor;
|
||||
using abstract::AbstractTuple;
|
||||
AbstractBasePtr MakeCOOTensorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, kSizeThree);
|
||||
auto indices = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexZero);
|
||||
auto values = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexOne);
|
||||
auto dense_shape = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, kIndexTwo);
|
||||
|
||||
auto indices_dtype = indices->element()->BuildType();
|
||||
CheckSparseIndicesDtype(indices_dtype, "Indices");
|
||||
|
||||
auto indices_shp = indices->shape()->shape();
|
||||
CheckSparseShape(indices_shp.size(), kSizeTwo, "Indices");
|
||||
|
||||
auto values_shp = values->shape()->shape();
|
||||
CheckSparseShape(values_shp.size(), kSizeOne, "Values");
|
||||
|
||||
if (indices_shp[kIndexZero] != values_shp[kIndexZero]) {
|
||||
MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexZero << "]` must be equal to `values.shape["
|
||||
<< kIndexZero << "]`, but got `indices.shape[" << kIndexZero
|
||||
<< "]`: " << indices_shp[kIndexZero] << " and `values.shape[" << kIndexZero
|
||||
<< "]`: " << values_shp[kIndexZero];
|
||||
}
|
||||
constexpr int64_t kDimTwo = 2;
|
||||
if (indices_shp[kIndexOne] != kDimTwo) {
|
||||
MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexOne << "]` must be " << kDimTwo << ",but got "
|
||||
<< indices_shp[kIndexOne];
|
||||
}
|
||||
|
||||
for (const auto &elem_type : dense_shape->ElementsType()) {
|
||||
if (!elem_type->isa<Int>()) {
|
||||
MS_EXCEPTION(TypeError) << "For COOTensor, the element type of `shape` must be Int, but got "
|
||||
<< elem_type->ToString();
|
||||
}
|
||||
}
|
||||
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(dense_shape_value);
|
||||
auto shp = dense_shape_value->value();
|
||||
auto min_elem = *std::min_element(std::begin(shp), std::end(shp));
|
||||
if (min_elem <= 0) {
|
||||
MS_EXCEPTION(ValueError) << "For COOTensor, the element of `shape` must be positive integer. But got " << min_elem
|
||||
<< "int it";
|
||||
}
|
||||
ShapeVector dense_shape_vec;
|
||||
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
|
||||
[](const ValuePtr &e) -> int64_t {
|
||||
auto elem = GetValue<int64_t>(e);
|
||||
return elem;
|
||||
});
|
||||
if (LongToSize(indices_shp[kIndexOne]) != dense_shape_vec.size()) {
|
||||
MS_EXCEPTION(TypeError) << "For COOTensor, `indices.shape[" << indices_shp << "]` must be equal to the second "
|
||||
<< "dimension of `indices`: " << dense_shape_vec.size() << " but got "
|
||||
<< indices_shp[kIndexOne];
|
||||
}
|
||||
for (auto dense_shape_elem : dense_shape_vec) {
|
||||
if (dense_shape_elem <= 0) {
|
||||
MS_EXCEPTION(TypeError) << "For COOTensor, the element of `shape` must be positive, but got "
|
||||
<< dense_shape_value->ToString();
|
||||
}
|
||||
}
|
||||
AbstractBasePtrList element_list{indices, values, dense_shape};
|
||||
return std::make_shared<abstract::AbstractCOOTensor>(element_list);
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(MakeCOOTensor, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MakeCOOTensor, prim::kPrimMakeCOOTensor, MakeCOOTensorInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_MAKE_COOTENSOR_H_
|
||||
#define MINDSPORE_CORE_OPS_MAKE_COOTENSOR_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameMakeCOOTensor = "MakeCOOTensor";
|
||||
class MIND_API MakeCOOTensor : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(MakeCOOTensor);
|
||||
/// \brief Constructor.
|
||||
MakeCOOTensor() : BaseOperator(kNameMakeCOOTensor) {}
|
||||
};
|
||||
abstract::AbstractBasePtr MakeCOOTensorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_MAKE_COOTENSOR_H_
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/make_csrtensor.h"
|
||||
|
||||
#include "abstract/dshape.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
using abstract::AbstractTensor;
|
||||
using abstract::AbstractTuple;
|
||||
AbstractBasePtr MakeCSRTensorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list) {
|
||||
// Inputs: three tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, kSizeFour);
|
||||
auto indptr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexZero);
|
||||
auto indices = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexOne);
|
||||
auto values = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, kIndexTwo);
|
||||
auto shape = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, kIndexThree);
|
||||
|
||||
auto indptr_dtype = indptr->element()->BuildType();
|
||||
auto indices_dtype = indices->element()->BuildType();
|
||||
CheckSparseIndicesDtype(indptr_dtype, "indptr");
|
||||
CheckSparseIndicesDtype(indices_dtype, "indices");
|
||||
|
||||
auto indptr_shp = indptr->shape()->shape();
|
||||
CheckSparseShape(indptr_shp.size(), kSizeOne, "Indptr");
|
||||
|
||||
auto indices_shp = indices->shape()->shape();
|
||||
CheckSparseShape(indices_shp.size(), kSizeOne, "Indices");
|
||||
|
||||
auto values_shp = values->shape()->shape();
|
||||
if (indices_shp[kIndexZero] != values_shp[kIndexZero]) {
|
||||
MS_EXCEPTION(ValueError) << "Indices and values must have same size, but got: values length: "
|
||||
<< values_shp[kIndexZero] << ", indices length " << indices_shp[kIndexZero];
|
||||
}
|
||||
|
||||
auto shape_value = shape->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_value);
|
||||
auto shp = shape_value->value();
|
||||
ShapeVector shape_vec;
|
||||
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(shape_vec), [](const ValuePtr &e) -> int64_t {
|
||||
auto elem = GetValue<int64_t>(e);
|
||||
return elem;
|
||||
});
|
||||
if (values_shp.size() + 1 != shape_vec.size()) {
|
||||
MS_EXCEPTION(ValueError) << "Values' dimension should equal to csr_tensor's dimension - 1, but got"
|
||||
<< "Values' dimension: " << values_shp.size()
|
||||
<< ", csr_tensor's dimension: " << shape_vec.size() << ".";
|
||||
}
|
||||
if (shape_vec[kIndexZero] + 1 != indptr_shp[kIndexZero]) {
|
||||
MS_EXCEPTION(ValueError) << "Indptr must have length (1 + shape[0]), but got: " << indptr_shp[kIndexZero];
|
||||
}
|
||||
size_t shape_size = 1;
|
||||
auto shape_types = shape->ElementsType();
|
||||
for (size_t i = 0; i < shape_vec.size(); ++i) {
|
||||
if (shape_vec[i] <= 0) {
|
||||
MS_EXCEPTION(ValueError) << "The element of shape must be positive, but got " << shape_value->ToString();
|
||||
}
|
||||
if ((i > 1) && (shape_vec[i] != values_shp[i - 1])) {
|
||||
MS_EXCEPTION(ValueError) << "csr_tensor's shape should match with values' shape.";
|
||||
}
|
||||
if (!shape_types[i]->isa<Int>()) {
|
||||
MS_EXCEPTION(TypeError) << "The element type of shape must be Int, but got " << shape_types[i]->ToString();
|
||||
}
|
||||
shape_size *= LongToSize(shape_vec[i]);
|
||||
}
|
||||
if (static_cast<int64_t>(shape_size) < values_shp[kIndexZero]) {
|
||||
MS_EXCEPTION(ValueError) << "Shape total size: " << shape_size << " is too small to hold " << values_shp[kIndexZero]
|
||||
<< " non-zero values.";
|
||||
}
|
||||
std::vector<abstract::AbstractBasePtr> element_list{indptr, indices, values, shape};
|
||||
return std::make_shared<abstract::AbstractCSRTensor>(element_list);
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(MakeCSRTensor, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(MakeCSRTensor, prim::kPrimMakeCSRTensor, MakeCSRTensorInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_MAKE_CSRTENSOR_H_
|
||||
#define MINDSPORE_CORE_OPS_MAKE_CSRTENSOR_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameMakeCSRTensor = "MakeCSRTensor";
|
||||
/// \brief MakeCSRTensor op is used to construct CSRTensor.
|
||||
class MIND_API MakeCSRTensor : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(MakeCSRTensor);
|
||||
/// \brief Constructor.
|
||||
MakeCSRTensor() : BaseOperator(kNameMakeCSRTensor) {}
|
||||
};
|
||||
abstract::AbstractBasePtr MakeCSRTensorInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &args_spec_list);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_MAKE_CSRTENSOR_H_
|
|
@ -466,5 +466,99 @@ ValuePtr InferMakeShapeTensorValue(const PrimitivePtr &prim, const AbstractBaseP
|
|||
ValuePtr InferComputeShapeTensorValue(const PrimitivePtr &prim, const AbstractBasePtrList &args) {
|
||||
return EvalShapeTensorValue(prim, args, true);
|
||||
}
|
||||
|
||||
void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp) {
|
||||
constexpr auto kCSRMulBatchPos = 2;
|
||||
int dlen = SizeToInt(sparse_shp.size()) - SizeToInt(dense_shp.size());
|
||||
if (dlen < 0) {
|
||||
MS_EXCEPTION(ValueError) << "Currently, only support dense tensor broadcast to sparse tensor, "
|
||||
<< "but sparse tensor has " << sparse_shp.size() << " dimensions, "
|
||||
<< "and dense tensor has " << dense_shp.size() << " dimensions. ";
|
||||
}
|
||||
for (int i = 0; i < dlen; i++) {
|
||||
(void)dense_shp.insert(dense_shp.begin(), 1);
|
||||
}
|
||||
if (sparse_shp.size() != dense_shp.size()) {
|
||||
MS_LOG(EXCEPTION) << "Failure: sparse_shp.size() != dense_shp.size().";
|
||||
}
|
||||
if (sparse_shp.size() < 1) {
|
||||
MS_LOG(EXCEPTION) << "Failure: dense tensor and sparse tensor shapes cannot be zero.";
|
||||
}
|
||||
for (size_t i = 0; i < sparse_shp.size(); i++) {
|
||||
auto s = sparse_shp[i];
|
||||
auto d = dense_shp[i];
|
||||
if (i < kCSRMulBatchPos) {
|
||||
if (d != s && d != 1) {
|
||||
MS_EXCEPTION(ValueError) << "Dense shape cannot broadcast to sparse shape.";
|
||||
}
|
||||
} else {
|
||||
if (d != s) {
|
||||
MS_EXCEPTION(ValueError) << "Currently, sparse shape and dense shape must equal in feature dimensions.";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CheckSparseShape(const size_t shape_size, const size_t expected_dim, const std::string &arg_name) {
|
||||
if (shape_size != expected_dim) {
|
||||
MS_EXCEPTION(ValueError) << arg_name << " must be a " << expected_dim << "-dimensional tensor, but got a "
|
||||
<< shape_size << "-dimensional tensor.";
|
||||
}
|
||||
}
|
||||
|
||||
void CheckSparseIndicesDtype(const TypePtr data_type, const std::string &arg_name) {
|
||||
if (!(data_type->equal(kInt16) || data_type->equal(kInt32) || data_type->equal(kInt64))) {
|
||||
MS_EXCEPTION(TypeError) << "The dtype of " << arg_name << " must be Int16 or Int32 or Int64, but got "
|
||||
<< data_type->ToString() << ".";
|
||||
}
|
||||
}
|
||||
|
||||
void CheckSparseIndicesDtypeInt32(const TypePtr data_type, const std::string &arg_name) {
|
||||
if (!data_type->equal(kInt32)) {
|
||||
MS_EXCEPTION(TypeError) << "The dtype of " << arg_name << " only support Int32 for now, but got "
|
||||
<< data_type->ToString() << ".";
|
||||
}
|
||||
}
|
||||
|
||||
ShapeVector ConvertToShapeVector(const abstract::AbstractTuplePtr &shape) {
|
||||
auto shape_value = shape->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_value);
|
||||
ShapeVector shape_vec;
|
||||
(void)std::transform(std::begin(shape_value->value()), std::end(shape_value->value()), std::back_inserter(shape_vec),
|
||||
[](const ValuePtr &e) -> int64_t {
|
||||
auto elem = GetValue<int64_t>(e);
|
||||
return elem;
|
||||
});
|
||||
return shape_vec;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
constexpr size_t kSizeExpect = 1;
|
||||
if (args_spec_list.size() != kSizeExpect) {
|
||||
MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', the number of input should be " << kSizeExpect
|
||||
<< ", but got " << args_spec_list.size() << ".";
|
||||
}
|
||||
constexpr size_t kIndex = 0;
|
||||
auto abs = args_spec_list[kIndex];
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
// To avoid AbstractSparseTensors being generalized to AbstractTuple.
|
||||
if (dyn_cast<T>(abs) == nullptr) {
|
||||
auto abs_tuple = dyn_cast<abstract::AbstractTuple>(abs);
|
||||
if (abs_tuple != nullptr) {
|
||||
return std::make_shared<T>(abs_tuple->elements());
|
||||
}
|
||||
} else if (dyn_cast<T>(abs) != nullptr) {
|
||||
return dyn_cast<T>(abs);
|
||||
}
|
||||
MS_EXCEPTION(TypeError) << "For \'" << primitive->name() << "\', input[" << kIndex
|
||||
<< "] should be AbstractSparseTensor or AbstractTuple, but got "
|
||||
<< abs->BuildType()->ToString() << ".";
|
||||
}
|
||||
template std::shared_ptr<abstract::AbstractCSRTensor> InferSparseAttr(const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
template std::shared_ptr<abstract::AbstractCOOTensor> InferSparseAttr(const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -80,5 +80,23 @@ ValuePtr InferMakeShapeTensorValue(const PrimitivePtr &prim, const AbstractBaseP
|
|||
// Infer shape value of compute-shape op that could change the dim value, e.g. Mul, Add, Sub
|
||||
// Do not support op with multiple outputs for now
|
||||
ValuePtr InferComputeShapeTensorValue(const PrimitivePtr &prim, const AbstractBasePtrList &args);
|
||||
|
||||
void CheckSparseShape(ShapeVector sparse_shp, ShapeVector dense_shp);
|
||||
|
||||
void CheckSparseShape(const size_t shape_size, const size_t expected_dim, const std::string &arg_name);
|
||||
|
||||
void CheckSparseIndicesDtype(const TypePtr data_type, const std::string &arg_name);
|
||||
|
||||
void CheckSparseIndicesDtypeInt32(const TypePtr data_type, const std::string &arg_name);
|
||||
|
||||
ShapeVector ConvertToShapeVector(const abstract::AbstractTuplePtr &shape);
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<T> InferSparseAttr(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
constexpr auto kCSRAvgRows = "csr_avg_rows";
|
||||
constexpr auto kIsCSR = "is_csr";
|
||||
constexpr auto kCSRDenseShape = "dense_shape";
|
||||
constexpr auto kCSRAxis = "axis";
|
||||
} // namespace mindspore::ops
|
||||
#endif // MINDSPORE_CORE_OPS_OP_UTILS_H
|
||||
|
|
|
@ -46,21 +46,6 @@ constexpr size_t kAlphaIndex = 10;
|
|||
constexpr size_t kBetaIndex = 11;
|
||||
constexpr int64_t kDefaultRank = 2;
|
||||
constexpr int64_t kBatchedRank = 3;
|
||||
|
||||
inline void CheckSparseShape(const size_t sparse_shape_size, const size_t expected_dim, const std::string &arg_name) {
|
||||
if (sparse_shape_size != expected_dim) {
|
||||
MS_EXCEPTION(mindspore::ValueError) << arg_name << " must be a " << expected_dim
|
||||
<< "-dimensional tensor, but got a " << sparse_shape_size
|
||||
<< "-dimensional tensor.";
|
||||
}
|
||||
}
|
||||
|
||||
inline void CheckSparseIndicesDtype(const mindspore::TypePtr dtype, const std::string &arg_name) {
|
||||
if (!(dtype->equal(mindspore::kInt16) || dtype->equal(mindspore::kInt32) || dtype->equal(mindspore::kInt64))) {
|
||||
MS_EXCEPTION(mindspore::TypeError) << "The dtype of " << arg_name << " must be Int16 or Int32 or Int64, but got "
|
||||
<< dtype->ToString() << ".";
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
void SparseMatrixAdd::set_dense_shape(const std::vector<int64_t> &shape) {
|
||||
(void)this->AddAttr(kDenseShape, api::MakeValue(shape));
|
||||
|
|
Loading…
Reference in New Issue