forked from mindspore-Ecosystem/mindspore
!46741 support dense to sparse coo dynamic rank

Merge pull request !46741 from 杨林枫/coo_to_sparse_drank
commit ab4d555d43
@@ -60,6 +60,8 @@ int DenseToCSRSparseMatrixCpuKernelMod::Resize(const BaseOperatorPtr &base_opera
   batch_size_ = (rank_ == kDefaultRank) ? kOne : dense_shape[kZero];
   num_rows_ = (rank_ == kDefaultRank) ? dense_shape[kZero] : dense_shape[kOne];
   num_cols_ = (rank_ == kDefaultRank) ? dense_shape[kOne] : dense_shape[kTwo];
+  total_ele_ = (rank_ == kDefaultRank) ? dense_shape[kZero] * dense_shape[kOne]
+                                       : dense_shape[kZero] * dense_shape[kOne] * dense_shape[kTwo];
   return KRET_OK;
 }
@@ -116,6 +118,13 @@ bool DenseToCSRSparseMatrixCpuKernelMod::Launch(const std::vector<kernel::Addres
   return true;
 }

+inline void CheckIndicesInRange(const size_t total_ele, const size_t idx, const std::string &name) {
+  if (idx >= total_ele) {
+    MS_LOG(EXCEPTION) << "For '" << name << "', flattened index must be in range: [0, " << total_ele
+                      << "), but got: " << idx << ".";
+  }
+}
+
 template <typename indiceT, typename valueT>
 void DenseToCSRSparseMatrixCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                                       const std::vector<AddressPtr> &outputs) const {
@@ -136,12 +145,14 @@ void DenseToCSRSparseMatrixCpuKernelMod::LaunchKernel(const std::vector<AddressP
   }
   for (size_t i = kZero; i < total_nnz_; i++) {
     if (rank_ == kDefaultRank) {
-      auto cur_idx = indices_ptr[i * rank_] * indiceT(num_cols_) + indices_ptr[i * rank_ + kOne];
-      y_values_ptr[i] = dense_input_ptr[LongToSize(cur_idx)];
+      auto cur_idx = LongToSize(indices_ptr[i * rank_] * indiceT(num_cols_) + indices_ptr[i * rank_ + kOne]);
+      CheckIndicesInRange(total_ele_, cur_idx, kernel_name_);
+      y_values_ptr[i] = dense_input_ptr[cur_idx];
     } else {
-      auto cur_idx = indices_ptr[i * rank_] * indiceT(num_rows_) * indiceT(num_cols_) +
-                     indices_ptr[i * rank_ + kOne] * indiceT(num_cols_) + indices_ptr[i * rank_ + kTwo];
-      y_values_ptr[i] = dense_input_ptr[LongToSize(cur_idx)];
+      auto cur_idx = LongToSize(indices_ptr[i * rank_] * indiceT(num_rows_) * indiceT(num_cols_) +
+                                indices_ptr[i * rank_ + kOne] * indiceT(num_cols_) + indices_ptr[i * rank_ + kTwo]);
+      CheckIndicesInRange(total_ele_, cur_idx, kernel_name_);
+      y_values_ptr[i] = dense_input_ptr[cur_idx];
     }
   }
   for (size_t i = kZero; i < batch_size_ * (num_rows_ + kOne); i++) {
@@ -52,6 +52,7 @@ class DenseToCSRSparseMatrixCpuKernelMod : public NativeCpuKernelMod {
   size_t num_rows_{0};
   size_t num_cols_{0};
   size_t total_nnz_{0};
+  size_t total_ele_{0};
   TypeId values_type_{kTypeUnknown};
   TypeId indices_type_{kTypeUnknown};
 };
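Taken together, the three kernel hunks above make the dense gather defensive: each sparse index is flattened to a linear offset and validated against the dense element count (total_ele_) before the dense buffer is dereferenced. A minimal NumPy sketch of the same flatten-and-check logic (the helper name gather_nnz is illustrative, not part of the MindSpore source):

import numpy as np

def gather_nnz(dense, indices):
    """Mirror the kernel: flatten each (row, col) or (batch, row, col) index
    and bounds-check it before reading from the dense buffer."""
    rank = indices.shape[1]
    total_ele = dense.size
    flat = dense.reshape(-1)
    values = np.empty(indices.shape[0], dtype=dense.dtype)
    for i, idx in enumerate(indices):
        if rank == 2:  # default rank: (row, col)
            cur = idx[0] * dense.shape[1] + idx[1]
        else:          # batched rank 3: (batch, row, col)
            cur = (idx[0] * dense.shape[1] + idx[1]) * dense.shape[2] + idx[2]
        if not 0 <= cur < total_ele:  # same condition as CheckIndicesInRange
            raise IndexError(f"flattened index must be in range [0, {total_ele}), but got {cur}")
        values[i] = flat[cur]
    return values

# gather_nnz(np.array([[2., 0., -1.], [0., 0., 1.]]), np.array([[0, 0], [0, 2], [1, 2]]))
# -> array([ 2., -1.,  1.])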
@@ -34,6 +34,7 @@ abstract::ShapePtr CSRSparseMatrixToDenseInferShape(const PrimitivePtr &primitiv
   const int64_t kOne = 1;
   const int64_t kDefalutRank = 2;
   const int64_t kBatchRank = 3;
+  CheckInputShapeEmpty(primitive->name(), input_args);
   auto d_shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
   auto b_ptrs_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
   auto r_ptrs_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
@@ -51,8 +52,9 @@ abstract::ShapePtr CSRSparseMatrixToDenseInferShape(const PrimitivePtr &primitiv
                             << values_shape.size() << ".";
   }
   // Dynamic Rank
-  if (IsDynamicRank(d_shape_shape) || IsDynamicRank(b_ptrs_shape) || IsDynamicRank(r_ptrs_shape) ||
-      IsDynamicRank(c_ind_shape) || IsDynamicRank(values_shape)) {
+  std::vector<ShapeVector> tensor_shapes{d_shape_shape, c_ind_shape, values_shape, r_ptrs_shape, b_ptrs_shape};
+  if (std::any_of(tensor_shapes.cbegin(), tensor_shapes.cend(),
+                  [](const ShapeVector shp) { return IsDynamicRank(shp); })) {
     ShapeVector dense_shape = {-2};
     return std::make_shared<abstract::Shape>(dense_shape);
   }
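Some context on the sentinel here, in case it looks arbitrary: in MindSpore shape inference a single unknown dimension is conventionally encoded as -1 and a completely unknown rank as the one-element shape {-2}, so returning {-2} propagates "rank not known yet" to downstream consumers. A small Python sketch of that convention (helper names are illustrative, not the C++ utilities themselves):

DYNAMIC_DIM = -1    # one dimension is unknown
DYNAMIC_RANK = -2   # the whole rank is unknown

def is_dynamic_rank(shape):
    return any(dim == DYNAMIC_RANK for dim in shape)

def is_dynamic(shape):
    return is_dynamic_rank(shape) or DYNAMIC_DIM in shape

# Mirrors the early return above: any dynamic-rank input makes the
# inferred dense shape dynamic-rank as well.
shapes = [(3, 4), (-2,)]
assert any(is_dynamic_rank(s) for s in shapes)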
@@ -70,7 +72,7 @@ abstract::ShapePtr CSRSparseMatrixToDenseInferShape(const PrimitivePtr &primitiv
     return std::make_shared<abstract::Shape>(dense_shape);
   }
   // Static Shape
-  if (values_shape[kZero] != c_ind_shape[kZero]) {
+  if (!IsDynamic(values_shape) && !IsDynamic(c_ind_shape) && values_shape[kZero] != c_ind_shape[kZero]) {
     MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', 'col_indices' and 'values' "
                              << "should have the same length.";
   }
@@ -95,7 +97,8 @@ abstract::ShapePtr CSRSparseMatrixToDenseInferShape(const PrimitivePtr &primitiv
   if (rank == kBatchRank) {
     batch_size = d_shape_value_ptr_tensor[kZero], row_num = d_shape_value_ptr_tensor[kOne];
   }
-  if (b_ptrs_shape[kZero] != (batch_size + kOne) || r_ptrs_shape[kZero] != batch_size * (row_num + kOne)) {
+  if (!IsDynamic(b_ptrs_shape) && !IsDynamic(r_ptrs_shape) &&
+      (b_ptrs_shape[kZero] != (batch_size + kOne) || r_ptrs_shape[kZero] != batch_size * (row_num + kOne))) {
     MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', batch size of the input is " << batch_size
                              << ", row numbers of the input is " << row_num << ", so shape of 'x_batch_pointers' "
                              << "should be (" << batch_size + kOne << "), but got (" << b_ptrs_shape[kZero] << ")"
@@ -59,6 +59,7 @@ abstract::TupleShapePtr CSRSparseMatrixToSparseTensorInferShape(const PrimitiveP
   const int64_t kOne = 1;
   const int64_t kDefalutRank = 2;
   const int64_t kBatchRank = 3;
+  CheckInputShapeEmpty(primitive->name(), input_args);
   std::vector<int64_t> x_dense_shape_shape =
     CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
   auto prim_name = primitive->name();
@@ -29,6 +29,7 @@ namespace ops {
 namespace {
 abstract::TupleShapePtr DenseToCSRSparseMatrixInferShape(const PrimitivePtr &primitive,
                                                          const std::vector<AbstractBasePtr> &input_args) {
+  CheckInputShapeEmpty(primitive->name(), input_args);
   auto dense_input_shape_ptr = input_args[kInputIndex0]->BuildShape();
   auto dense_input_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(dense_input_shape_ptr);
   auto dense_input_shape = dense_input_shape_map[kShape];
@@ -28,6 +28,7 @@
 #include "ops/primitive_c.h"
 #include "utils/anf_utils.h"
 #include "utils/check_convert_utils.h"
+#include "utils/shape_utils.h"
 #include "utils/tensor_construct_utils.h"

 namespace mindspore {
@@ -62,15 +63,17 @@ AbstractBasePtr MakeCOOTensorInfer(const abstract::AnalysisEnginePtr &, const Pr
   }

   // Convert dense_shape from tuple to shapevector(dense_shape_vec)
+  auto dense_shape_vec = GetShapeValue(primitive, dense_shape);
   auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
-  MS_EXCEPTION_IF_NULL(dense_shape_value);
-  auto shp = dense_shape_value->value();
-  ShapeVector dense_shape_vec;
-  (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
-                       [](const ValuePtr &e) -> int64_t {
-                         auto elem = GetValue<int64_t>(e);
-                         return elem;
-                       });
+  if (!IsDynamic(dense_shape_vec)) {
+    MS_EXCEPTION_IF_NULL(dense_shape_value);
+    for (auto dense_shape_elem : dense_shape_vec) {
+      if (dense_shape_elem <= 0) {
+        MS_EXCEPTION(TypeError) << "For COOTensor, the element of `shape` must be positive, but got "
+                                << dense_shape_value->ToString();
+      }
+    }
+  }

   if (IsDynamic(indices_shp) || IsDynamic(values_shp)) {
     MS_LOG(DEBUG) << "Dynamic shape in MakeCOOTensor's inputs! Ignore shape check.";
@@ -78,6 +81,9 @@ AbstractBasePtr MakeCOOTensorInfer(const abstract::AnalysisEnginePtr &, const Pr
     return std::make_shared<abstract::AbstractCOOTensor>(element_list);
   }

+  CheckSparseShape(indices_shp.size(), kSizeTwo, "Indices");
+  CheckSparseShape(values_shp.size(), kSizeOne, "Values");
+
   if (indices_shp[kIndexZero] != values_shp[kIndexZero]) {
     MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexZero << "]` must be equal to `values.shape["
                              << kIndexZero << "]`, but got `indices.shape[" << kIndexZero
@@ -90,17 +96,11 @@ AbstractBasePtr MakeCOOTensorInfer(const abstract::AnalysisEnginePtr &, const Pr
                              << indices_shp[kIndexOne];
   }

-  if (LongToSize(indices_shp[kIndexOne]) != dense_shape_vec.size()) {
+  if (!IsDynamicRank(dense_shape_vec) && LongToSize(indices_shp[kIndexOne]) != dense_shape_vec.size()) {
     MS_EXCEPTION(TypeError) << "For COOTensor, `indices.shape[" << indices_shp << "]` must be equal to the second "
                             << "dimension of `indices`: " << dense_shape_vec.size() << " but got "
                             << indices_shp[kIndexOne];
   }
-  for (auto dense_shape_elem : dense_shape_vec) {
-    if (dense_shape_elem <= 0) {
-      MS_EXCEPTION(TypeError) << "For COOTensor, the element of `shape` must be positive, but got "
-                              << dense_shape_value->ToString();
-    }
-  }
   AbstractBasePtrList element_list{indices, values, dense_shape};
   return std::make_shared<abstract::AbstractCOOTensor>(element_list);
 }
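The net effect of the MakeCOOTensorInfer hunks: the shape value now comes from GetShapeValue, which may contain -1/-2 placeholders, so every check that needs concrete numbers is wrapped in an IsDynamic/IsDynamicRank guard and otherwise deferred to runtime. A hedged sketch of that control flow, reusing the illustrative helpers from the earlier sketch:

def validate_coo(indices_shape, values_shape, dense_shape_vec):
    """Run only the checks the available (possibly dynamic) shapes permit."""
    if not is_dynamic(dense_shape_vec):
        # Element positivity needs concrete dimension values.
        if any(dim <= 0 for dim in dense_shape_vec):
            raise TypeError("For COOTensor, the element of `shape` must be positive")
    if not is_dynamic_rank(dense_shape_vec) and indices_shape[1] != len(dense_shape_vec):
        # Rank consistency only needs the rank, not the dimension values.
        raise TypeError("`indices.shape[1]` must equal the length of `shape`")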
@@ -28,6 +28,7 @@
 #include "ops/primitive_c.h"
 #include "utils/anf_utils.h"
 #include "utils/check_convert_utils.h"
+#include "utils/shape_utils.h"
 #include "utils/tensor_construct_utils.h"

 namespace mindspore {
@@ -55,6 +56,7 @@ AbstractBasePtr MakeCSRTensorInfer(const abstract::AnalysisEnginePtr &, const Pr
   auto indices_shp = indices->shape()->shape();
   CheckSparseShape(indices_shp.size(), kSizeOne, "Indices");

+  auto values_shp = values->shape()->shape();
   for (const auto &elem_type : shape->ElementsType()) {
     if (!elem_type->isa<Int>()) {
       MS_EXCEPTION(TypeError) << "The element type of shape must be Int, but got " << elem_type->ToString();
@@ -62,21 +64,8 @@ AbstractBasePtr MakeCSRTensorInfer(const abstract::AnalysisEnginePtr &, const Pr
   }

   // convert shape from tuple to shapevector(shape_vec)
+  auto shape_vec = GetShapeValue(primitive, shape);
   auto shape_value = shape->BuildValue()->cast<ValueTuplePtr>();
-  MS_EXCEPTION_IF_NULL(shape_value);
-  auto shp = shape_value->value();
-  ShapeVector shape_vec;
-  (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(shape_vec), [](const ValuePtr &e) -> int64_t {
-    auto elem = GetValue<int64_t>(e);
-    return elem;
-  });
-
-  auto values_shp = values->shape()->shape();
-  if (values_shp.size() + 1 != shape_vec.size()) {
-    MS_EXCEPTION(ValueError) << "Values' dimension should equal to CSRTensor's dimension - 1, but got"
-                             << "Values' dimension: " << values_shp.size()
-                             << ", CSRTensor's dimension: " << shape_vec.size() << ".";
-  }

   if (IsShapeEmpty(indptr_shp) && IsShapeEmpty(indices_shp) && IsShapeEmpty(values_shp)) {
     MS_LOG(DEBUG) << "Constructing empty CSRTensor! Ignore further shape check.";
@@ -90,7 +79,31 @@ AbstractBasePtr MakeCSRTensorInfer(const abstract::AnalysisEnginePtr &, const Pr
     return std::make_shared<abstract::AbstractCSRTensor>(element_list);
   }

-  if (values_shp.size() + 1 != shape_vec.size()) {
+  if (!IsDynamic(shape_vec)) {
+    MS_EXCEPTION_IF_NULL(shape_value);
+    size_t shape_size = 1;
+    for (size_t i = 0; i < shape_vec.size(); ++i) {
+      if (shape_vec[i] <= 0) {
+        MS_EXCEPTION(ValueError) << "The element of shape must be positive, but got " << shape_value->ToString();
+      }
+      if ((i > 1) && (shape_vec[i] != values_shp[i - 1])) {
+        MS_EXCEPTION(ValueError)
+          << "CSRTensor's shape[2: ] must be equal to value's shape[1: ], but CSRTensor's shape got: "
+          << shape_value->ToString() << ", "
+          << "values's shape got: " << values->shape()->ToString() << ".";
+      }
+      shape_size *= LongToSize(shape_vec[i]);
+    }
+    if (static_cast<int64_t>(shape_size) < values_shp[kIndexZero]) {
+      MS_EXCEPTION(ValueError) << "Shape total size: " << shape_size << " is too small to hold "
+                               << values_shp[kIndexZero] << " non-zero values.";
+    }
+    if (shape_vec[kIndexZero] + 1 != indptr_shp[kIndexZero]) {
+      MS_EXCEPTION(ValueError) << "Indptr must have length (1 + shape[0]), but got: " << indptr_shp[kIndexZero];
+    }
+  }
+
+  if (!IsDynamicRank(shape_vec) && values_shp.size() + 1 != shape_vec.size()) {
     MS_EXCEPTION(ValueError) << "Values' dimension should equal to CSRTensor's dimension - 1, but got"
                              << "Values' dimension: " << values_shp.size()
                              << ", CSRTensor's dimension: " << shape_vec.size() << ".";
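Same pattern for CSRTensor: everything that indexes into shape_vec (positivity, shape[2:] against values.shape[1:], total size against nnz, indptr length) now runs only when the shape is fully static, while the cheaper rank-consistency check merely requires the rank to be known. Sketched with the same illustrative helpers:

def validate_csr(shape_vec, values_shape, indptr_shape):
    """Run full value checks only for static shapes; rank check otherwise."""
    if not is_dynamic(shape_vec):
        size = 1
        for i, dim in enumerate(shape_vec):
            if dim <= 0:
                raise ValueError("The element of shape must be positive")
            if i > 1 and dim != values_shape[i - 1]:
                raise ValueError("CSRTensor's shape[2:] must equal values' shape[1:]")
            size *= dim
        if size < values_shape[0]:
            raise ValueError("shape total size is too small to hold the non-zero values")
        if shape_vec[0] + 1 != indptr_shape[0]:
            raise ValueError("indptr must have length (1 + shape[0])")
    if not is_dynamic_rank(shape_vec) and len(values_shape) + 1 != len(shape_vec):
        raise ValueError("values' dimension should equal CSRTensor's dimension - 1")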
@@ -101,27 +114,6 @@ AbstractBasePtr MakeCSRTensorInfer(const abstract::AnalysisEnginePtr &, const Pr
                              << values_shp[kIndexZero] << ", indices length " << indices_shp[kIndexZero];
   }

-  if (shape_vec[kIndexZero] + 1 != indptr_shp[kIndexZero]) {
-    MS_EXCEPTION(ValueError) << "Indptr must have length (1 + shape[0]), but got: " << indptr_shp[kIndexZero];
-  }
-
-  size_t shape_size = 1;
-  for (size_t i = 0; i < shape_vec.size(); ++i) {
-    if (shape_vec[i] <= 0) {
-      MS_EXCEPTION(ValueError) << "The element of shape must be positive, but got " << shape_value->ToString();
-    }
-    if ((i > 1) && (shape_vec[i] != values_shp[i - 1])) {
-      MS_EXCEPTION(ValueError)
-        << "CSRTensor's shape[2: ] must be equal to value's shape[1: ], but CSRTensor's shape got: "
-        << shape_value->ToString() << ", "
-        << "values's shape got: " << values->shape()->ToString() << ".";
-    }
-    shape_size *= LongToSize(shape_vec[i]);
-  }
-  if (static_cast<int64_t>(shape_size) < values_shp[kIndexZero]) {
-    MS_EXCEPTION(ValueError) << "Shape total size: " << shape_size << " is too small to hold " << values_shp[kIndexZero]
-                             << " non-zero values.";
-  }
   std::vector<abstract::AbstractBasePtr> element_list{indptr, indices, values, shape};
   return std::make_shared<abstract::AbstractCSRTensor>(element_list);
 }
@@ -95,6 +95,15 @@ void CheckSparseIndicesDtype(const TypePtr data_type, const std::string &arg_nam

 void CheckSparseIndicesDtypeInt32(const TypePtr data_type, const std::string &arg_name);

+inline void CheckInputShapeEmpty(const std::string &prim_name, const std::vector<AbstractBasePtr> &input_args) {
+  for (size_t i = 0; i < input_args.size(); ++i) {
+    MS_EXCEPTION_IF_NULL(input_args[i]->BuildShape());
+    if (input_args[i]->BuildShape()->IsDimZero()) {
+      MS_LOG(EXCEPTION) << "For '" << prim_name << "', input " << i << "'s shape should not be empty!";
+    }
+  }
+}
+
 ShapeVector ConvertToShapeVector(const abstract::AbstractTuplePtr &shape);

 template <typename T>
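CheckInputShapeEmpty is the shared helper the infer functions above now call first; it rejects any operator input whose inferred shape is empty before the shape math begins. A rough Python analog, assuming shapes arrive as plain sequences (this is a sketch of the intent, not the exact IsDimZero semantics):

def check_input_shape_empty(prim_name, input_shapes):
    """Reject inputs whose inferred shape is empty before further checks."""
    for i, shp in enumerate(input_shapes):
        if shp is None or len(shp) == 0:
            raise RuntimeError(f"For '{prim_name}', input {i}'s shape should not be empty!")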
@@ -63,6 +63,15 @@ def _make_tensor_with_dtype(data, dtype):
     return Tensor(data, dtype=dtype)


+@constexpr
+def _convert_shape(shape):
+    """Temporary solution to get shape value, will be removed when shape op is supported."""
+    if shape is None:
+        return (-2,)
+    shape = [-1 if i is None else i for i in shape]
+    return tuple(shape)
+
+
 def is_scalar(tensor):
     """Determine whether tensor input is a scalar tensor."""
     if tensor.size != 1:
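Based on the code above, the helper maps Python-side unknowns (None) onto the graph-side -1/-2 conventions:

# _convert_shape(None)      -> (-2,)    unknown shape: dynamic rank
# _convert_shape((None, 3)) -> (-1, 3)  unknown dim: dynamic dimension
# _convert_shape((2, 3))    -> (2, 3)   static shape passes through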
@@ -361,12 +370,12 @@ def csr_to_coo(tensor: CSRTensor) -> COOTensor:
     """
     if not isinstance(tensor, CSRTensor):
         raise_type_error("For functional operator csr_to_coo, input argument must be a CSRTensor.")
-    if len(tensor.shape) != 2:
+    if len(_convert_shape(tensor.shape)) > 2:
         raise_value_error("Currently only support 2-D CSRTensor when converting to COOTensor.")
     shape = tensor.shape
     indices, values, _ = csr_sparse_matrix_to_sparse_tensor(Tensor(shape, mstype.int32), batch_csr_pointers_empty,
                                                             tensor.indptr, tensor.indices, tensor.values)
-    return COOTensor(indices, values, shape)
+    return COOTensor(indices, values, _convert_shape(shape))


 def csr_to_dense(csr_tensor: CSRTensor) -> Tensor:
@@ -401,12 +410,12 @@ def csr_to_dense(csr_tensor: CSRTensor) -> Tensor:
     """
     if not isinstance(csr_tensor, CSRTensor):
         raise_type_error("For functional operator csr_to_dense, input argument must be a CSRTensor.")
-    if len(csr_tensor.shape) != 2:
+    if len(csr_tensor.shape) > 2:
         raise_value_error("Currently only support 2-D CSRTensor when converting to COOTensor.")

-    shape = csr_tensor.shape
+    shape = _convert_shape(csr_tensor.shape)
     dense_shape = Tensor(shape, dtype=mstype.int32)
-    batch_pointers = Tensor([0, -1], dtype=mstype.int32)
+    batch_pointers = ops.concat((make_tensor([0]), ops.TensorShape()(csr_tensor.values))).astype("int32")
     row_pointers = csr_tensor.indptr
     col_indices = csr_tensor.indices
     values = csr_tensor.values
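The batch_pointers change deserves a note: for a single (non-batched) CSR matrix the batch pointer vector is [0, nnz], and the old hard-coded Tensor([0, -1]) stood in for an nnz that was unknown at trace time. The new code derives nnz from the runtime shape of values instead. A standalone sketch, using Tensor([0]) in place of the module's internal make_tensor helper:

from mindspore import Tensor, ops
from mindspore.common import dtype as mstype

values = Tensor([2.0, -1.0, 1.0], mstype.float32)  # nnz == 3
nnz = ops.TensorShape()(values)                    # 1-D tensor [3], known only at runtime
batch_pointers = ops.concat((Tensor([0], mstype.int64), nnz)).astype("int32")
# batch_pointers -> Tensor([0, 3], dtype=int32)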
@@ -472,11 +481,11 @@ def dense_to_sparse_coo(tensor: Tensor) -> COOTensor:
     """
     if not isinstance(tensor, Tensor):
         raise_type_error("For functional operator dense_to_sparse_coo, input argument must be a Tensor.")
-    if len(tensor.shape) != 2:
+    if len(_convert_shape(tensor.shape)) > 2:
         raise_value_error("Currently only support 2-D Tensor when converting to COOTensor.")
     indices = tensor.nonzero().astype("int32")
     values = gather_nd(tensor, indices)
-    return COOTensor(indices, values, tensor.shape)
+    return COOTensor(indices, values, _convert_shape(tensor.shape))


 def dense_to_sparse_csr(tensor: Tensor) -> CSRTensor:
@@ -518,11 +527,11 @@ def dense_to_sparse_csr(tensor: Tensor) -> CSRTensor:
     """
     if not isinstance(tensor, Tensor):
         raise_type_error("For functional operator dense_to_sparse_csr, input argument must be a Tensor.")
-    if len(tensor.shape) > 2:
+    if len(_convert_shape(tensor.shape)) > 2:
         raise_value_error("Currently only support 2-D Tensor when converting to CSRTensor.")
     indices = tensor.nonzero().astype("int32")
     _, _, indptr, indices, values = dense_to_csr(tensor, indices)
-    return CSRTensor(indptr, indices, values, tensor.shape)
+    return CSRTensor(indptr, indices, values, _convert_shape(tensor.shape))


 def make_sparse_tensor(indices, values, dense_shape):
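For a static-shape sanity check of the converters touched above (the same tensor the new tests below use):

import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype

x = Tensor(np.array([[2, 0, -1], [0, 0, 1]]), mstype.float32)
coo = x.to_coo()  # indices: (3, 2) int32, values: (3,), shape: (2, 3)
csr = x.to_csr()  # indptr: (3,), indices: (3,), values: (3,)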
@@ -0,0 +1,135 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""smoke tests for sparse dynamic shape operations"""
+import pytest
+import numpy as np
+from mindspore import Tensor, nn, CSRTensor, ops
+from mindspore.common import dtype as mstype
+from tests.st.ops.dynamic_shape.grad.test_grad_of_dynamic import TestDynamicGrad
+
+
+class NetDenseToCSR(nn.Cell):
+
+    def construct(self, x):
+        csr = x.to_csr()
+        return csr.indptr, csr.indices, csr.values
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_dense_to_csr():
+    """
+    Feature: Test tensor.to_csr in Graph and PyNative.
+    Description: Test tensor.to_csr in dynamic rank and dynamic shape.
+    Expectation: Success.
+    """
+    test_dynamic = TestDynamicGrad(NetDenseToCSR())
+    x = Tensor(np.array([[2, 0, -1], [0, 0, 1]]), mstype.float32)
+    test_dynamic.test_dynamic_grad_net((x))
+    test_dynamic.test_dynamic_grad_net((x), is_dynamic_rank=True)
+
+
+class NetDenseToCOO(nn.Cell):
+
+    def construct(self, x):
+        coo = x.to_coo()
+        return coo.indices, coo.values
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_dense_to_coo():
+    """
+    Feature: Test tensor.to_coo in Graph and PyNative.
+    Description: Test tensor.to_coo in dynamic rank and dynamic shape.
+    Expectation: Success.
+    """
+    test_dynamic = TestDynamicGrad(NetDenseToCOO())
+    x = Tensor(np.array([[2, 0, -1], [0, 0, 1]]), mstype.float32)
+    test_dynamic.test_dynamic_grad_net((x))
+    test_dynamic.test_dynamic_grad_net((x), is_dynamic_rank=True)
+
+
+class NetCSRToCOO(nn.Cell):
+    def construct(self, indptr, indices, values, shape):
+        csr = CSRTensor(indptr, indices, values, shape)
+        coo = ops.csr_to_coo(csr)
+        return coo.indices, coo.values
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_csr_to_coo():
+    """
+    Feature: Test ops.csr_to_coo in Graph and PyNative.
+    Description: Test ops.csr_to_coo in dynamic rank and dynamic shape.
+    Expectation: Success.
+    """
+    test_dynamic = TestDynamicGrad(NetCSRToCOO())
+    x = Tensor(np.array([[2, 0, -1], [0, 0, 1]]), mstype.float32).to_csr()
+    args = (x.indptr, x.indices, x.values, x.shape)
+    test_dynamic.test_dynamic_grad_net(args)
+    test_dynamic.test_dynamic_grad_net(args, is_dynamic_rank=True)
+
+
+class NetCSRToDense(nn.Cell):
+
+    def construct(self, indptr, indices, values, dense_shape):
+        x = CSRTensor(indptr, indices, values, dense_shape)
+        return x.to_dense()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_csr_to_dense_dshape():
+    """
+    Feature: Test csr_tensor.to_dense in Graph and PyNative.
+    Description: Test csr_tensor.to_dense in dynamic shape.
+    Expectation: Success.
+    """
+    test_dynamic = TestDynamicGrad(NetCSRToDense())
+    indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
+    indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
+    values = Tensor(np.arange(1, 7), dtype=mstype.float32)
+    dense_shape = (3, 4)
+    x = (indptr, indices, values, dense_shape)
+    test_dynamic.test_dynamic_grad_net(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_csr_to_dense_drank():
+    """
+    Feature: Test csr_tensor.to_dense in Graph and PyNative.
+    Description: Test csr_tensor.to_dense in dynamic rank.
+    Expectation: Success.
+    """
+    test_dynamic = TestDynamicGrad(NetCSRToDense())
+    indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
+    indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
+    values = Tensor(np.arange(1, 7), dtype=mstype.float32)
+    dense_shape = (3, 4)
+    x = (indptr, indices, values, dense_shape)
+    test_dynamic.test_dynamic_grad_net(x, is_dynamic_rank=True)