!46741 support dense to sparse coo dynamic rank

Merge pull request !46741 from 杨林枫/coo_to_sparse_drank
This commit is contained in:
i-robot 2022-12-27 06:44:13 +00:00 committed by Gitee
commit ab4d555d43
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 231 additions and 69 deletions

View File

@ -60,6 +60,8 @@ int DenseToCSRSparseMatrixCpuKernelMod::Resize(const BaseOperatorPtr &base_opera
batch_size_ = (rank_ == kDefaultRank) ? kOne : dense_shape[kZero];
num_rows_ = (rank_ == kDefaultRank) ? dense_shape[kZero] : dense_shape[kOne];
num_cols_ = (rank_ == kDefaultRank) ? dense_shape[kOne] : dense_shape[kTwo];
total_ele_ = (rank_ == kDefaultRank) ? dense_shape[kZero] * dense_shape[kOne]
: dense_shape[kZero] * dense_shape[kOne] * dense_shape[kTwo];
return KRET_OK;
}
@ -116,6 +118,13 @@ bool DenseToCSRSparseMatrixCpuKernelMod::Launch(const std::vector<kernel::Addres
return true;
}
inline void CheckIndicesInRange(const size_t total_ele, const size_t idx, const std::string &name) {
if (idx >= total_ele) {
MS_LOG(EXCEPTION) << "For '" << name << "', flattened index must in range: [0, " << total_ele
<< "), but got: " << idx << ".";
}
}
template <typename indiceT, typename valueT>
void DenseToCSRSparseMatrixCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) const {
@ -136,12 +145,14 @@ void DenseToCSRSparseMatrixCpuKernelMod::LaunchKernel(const std::vector<AddressP
}
for (size_t i = kZero; i < total_nnz_; i++) {
if (rank_ == kDefaultRank) {
auto cur_idx = indices_ptr[i * rank_] * indiceT(num_cols_) + indices_ptr[i * rank_ + kOne];
y_values_ptr[i] = dense_input_ptr[LongToSize(cur_idx)];
auto cur_idx = LongToSize(indices_ptr[i * rank_] * indiceT(num_cols_) + indices_ptr[i * rank_ + kOne]);
CheckIndicesInRange(total_ele_, cur_idx, kernel_name_);
y_values_ptr[i] = dense_input_ptr[cur_idx];
} else {
auto cur_idx = indices_ptr[i * rank_] * indiceT(num_rows_) * indiceT(num_cols_) +
indices_ptr[i * rank_ + kOne] * indiceT(num_cols_) + indices_ptr[i * rank_ + kTwo];
y_values_ptr[i] = dense_input_ptr[LongToSize(cur_idx)];
auto cur_idx = LongToSize(indices_ptr[i * rank_] * indiceT(num_rows_) * indiceT(num_cols_) +
indices_ptr[i * rank_ + kOne] * indiceT(num_cols_) + indices_ptr[i * rank_ + kTwo]);
CheckIndicesInRange(total_ele_, cur_idx, kernel_name_);
y_values_ptr[i] = dense_input_ptr[cur_idx];
}
}
for (size_t i = kZero; i < batch_size_ * (num_rows_ + kOne); i++) {

View File

@ -52,6 +52,7 @@ class DenseToCSRSparseMatrixCpuKernelMod : public NativeCpuKernelMod {
size_t num_rows_{0};
size_t num_cols_{0};
size_t total_nnz_{0};
size_t total_ele_{0};
TypeId values_type_{kTypeUnknown};
TypeId indices_type_{kTypeUnknown};
};

View File

@ -34,6 +34,7 @@ abstract::ShapePtr CSRSparseMatrixToDenseInferShape(const PrimitivePtr &primitiv
const int64_t kOne = 1;
const int64_t kDefalutRank = 2;
const int64_t kBatchRank = 3;
CheckInputShapeEmpty(primitive->name(), input_args);
auto d_shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto b_ptrs_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto r_ptrs_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
@ -51,8 +52,9 @@ abstract::ShapePtr CSRSparseMatrixToDenseInferShape(const PrimitivePtr &primitiv
<< values_shape.size() << ".";
}
// Dynamic Rank
if (IsDynamicRank(d_shape_shape) || IsDynamicRank(b_ptrs_shape) || IsDynamicRank(r_ptrs_shape) ||
IsDynamicRank(c_ind_shape) || IsDynamicRank(values_shape)) {
std::vector<ShapeVector> tensor_shapes{d_shape_shape, c_ind_shape, values_shape, r_ptrs_shape, b_ptrs_shape};
if (std::any_of(tensor_shapes.cbegin(), tensor_shapes.cend(),
[](const ShapeVector shp) { return IsDynamicRank(shp); })) {
ShapeVector dense_shape = {-2};
return std::make_shared<abstract::Shape>(dense_shape);
}
@ -70,7 +72,7 @@ abstract::ShapePtr CSRSparseMatrixToDenseInferShape(const PrimitivePtr &primitiv
return std::make_shared<abstract::Shape>(dense_shape);
}
// Static Shape
if (values_shape[kZero] != c_ind_shape[kZero]) {
if (!IsDynamic(values_shape) && !IsDynamic(c_ind_shape) && values_shape[kZero] != c_ind_shape[kZero]) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', 'col_indices' and 'values' "
<< "should have the same length.";
}
@ -95,7 +97,8 @@ abstract::ShapePtr CSRSparseMatrixToDenseInferShape(const PrimitivePtr &primitiv
if (rank == kBatchRank) {
batch_size = d_shape_value_ptr_tensor[kZero], row_num = d_shape_value_ptr_tensor[kOne];
}
if (b_ptrs_shape[kZero] != (batch_size + kOne) || r_ptrs_shape[kZero] != batch_size * (row_num + kOne)) {
if (!IsDynamic(b_ptrs_shape) && !IsDynamic(r_ptrs_shape) &&
(b_ptrs_shape[kZero] != (batch_size + kOne) || r_ptrs_shape[kZero] != batch_size * (row_num + kOne))) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', batch size of the input is " << batch_size
<< ", row numbers of the input is " << row_num << ", so shape of 'x_batch_pointers' "
<< "should be (" << batch_size + kOne << "), but got (" << b_ptrs_shape[kZero] << ")"

View File

@ -59,6 +59,7 @@ abstract::TupleShapePtr CSRSparseMatrixToSparseTensorInferShape(const PrimitiveP
const int64_t kOne = 1;
const int64_t kDefalutRank = 2;
const int64_t kBatchRank = 3;
CheckInputShapeEmpty(primitive->name(), input_args);
std::vector<int64_t> x_dense_shape_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto prim_name = primitive->name();

View File

@ -29,6 +29,7 @@ namespace ops {
namespace {
abstract::TupleShapePtr DenseToCSRSparseMatrixInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
CheckInputShapeEmpty(primitive->name(), input_args);
auto dense_input_shape_ptr = input_args[kInputIndex0]->BuildShape();
auto dense_input_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(dense_input_shape_ptr);
auto dense_input_shape = dense_input_shape_map[kShape];

View File

@ -28,6 +28,7 @@
#include "ops/primitive_c.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/shape_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
@ -62,15 +63,17 @@ AbstractBasePtr MakeCOOTensorInfer(const abstract::AnalysisEnginePtr &, const Pr
}
// Convert dense_shape from tuple to shapevector(dense_shape_vec)
auto dense_shape_vec = GetShapeValue(primitive, dense_shape);
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(dense_shape_value);
auto shp = dense_shape_value->value();
ShapeVector dense_shape_vec;
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
[](const ValuePtr &e) -> int64_t {
auto elem = GetValue<int64_t>(e);
return elem;
});
if (!IsDynamic(dense_shape_vec)) {
MS_EXCEPTION_IF_NULL(dense_shape_value);
for (auto dense_shape_elem : dense_shape_vec) {
if (dense_shape_elem <= 0) {
MS_EXCEPTION(TypeError) << "For COOTensor, the element of `shape` must be positive, but got "
<< dense_shape_value->ToString();
}
}
}
if (IsDynamic(indices_shp) || IsDynamic(values_shp)) {
MS_LOG(DEBUG) << "Dynamic shape in MakeCOOTensor's inputs! Ignore shape check.";
@ -78,6 +81,9 @@ AbstractBasePtr MakeCOOTensorInfer(const abstract::AnalysisEnginePtr &, const Pr
return std::make_shared<abstract::AbstractCOOTensor>(element_list);
}
CheckSparseShape(indices_shp.size(), kSizeTwo, "Indices");
CheckSparseShape(values_shp.size(), kSizeOne, "Values");
if (indices_shp[kIndexZero] != values_shp[kIndexZero]) {
MS_EXCEPTION(ValueError) << "For COOTensor, `indices.shape[" << kIndexZero << "]` must be equal to `values.shape["
<< kIndexZero << "]`, but got `indices.shape[" << kIndexZero
@ -90,17 +96,11 @@ AbstractBasePtr MakeCOOTensorInfer(const abstract::AnalysisEnginePtr &, const Pr
<< indices_shp[kIndexOne];
}
if (LongToSize(indices_shp[kIndexOne]) != dense_shape_vec.size()) {
if (!IsDynamicRank(dense_shape_vec) && LongToSize(indices_shp[kIndexOne]) != dense_shape_vec.size()) {
MS_EXCEPTION(TypeError) << "For COOTensor, `indices.shape[" << indices_shp << "]` must be equal to the second "
<< "dimension of `indices`: " << dense_shape_vec.size() << " but got "
<< indices_shp[kIndexOne];
}
for (auto dense_shape_elem : dense_shape_vec) {
if (dense_shape_elem <= 0) {
MS_EXCEPTION(TypeError) << "For COOTensor, the element of `shape` must be positive, but got "
<< dense_shape_value->ToString();
}
}
AbstractBasePtrList element_list{indices, values, dense_shape};
return std::make_shared<abstract::AbstractCOOTensor>(element_list);
}

View File

@ -28,6 +28,7 @@
#include "ops/primitive_c.h"
#include "utils/anf_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/shape_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
@ -55,6 +56,7 @@ AbstractBasePtr MakeCSRTensorInfer(const abstract::AnalysisEnginePtr &, const Pr
auto indices_shp = indices->shape()->shape();
CheckSparseShape(indices_shp.size(), kSizeOne, "Indices");
auto values_shp = values->shape()->shape();
for (const auto &elem_type : shape->ElementsType()) {
if (!elem_type->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The element type of shape must be Int, but got " << elem_type->ToString();
@ -62,21 +64,8 @@ AbstractBasePtr MakeCSRTensorInfer(const abstract::AnalysisEnginePtr &, const Pr
}
// convert shape from tuple to shapevector(shape_vec)
auto shape_vec = GetShapeValue(primitive, shape);
auto shape_value = shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(shape_value);
auto shp = shape_value->value();
ShapeVector shape_vec;
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(shape_vec), [](const ValuePtr &e) -> int64_t {
auto elem = GetValue<int64_t>(e);
return elem;
});
auto values_shp = values->shape()->shape();
if (values_shp.size() + 1 != shape_vec.size()) {
MS_EXCEPTION(ValueError) << "Values' dimension should equal to CSRTensor's dimension - 1, but got"
<< "Values' dimension: " << values_shp.size()
<< ", CSRTensor's dimension: " << shape_vec.size() << ".";
}
if (IsShapeEmpty(indptr_shp) && IsShapeEmpty(indices_shp) && IsShapeEmpty(values_shp)) {
MS_LOG(DEBUG) << "Constructing empty CSRTensor! Ignore further shape check.";
@ -90,7 +79,31 @@ AbstractBasePtr MakeCSRTensorInfer(const abstract::AnalysisEnginePtr &, const Pr
return std::make_shared<abstract::AbstractCSRTensor>(element_list);
}
if (values_shp.size() + 1 != shape_vec.size()) {
if (!IsDynamic(shape_vec)) {
MS_EXCEPTION_IF_NULL(shape_value);
size_t shape_size = 1;
for (size_t i = 0; i < shape_vec.size(); ++i) {
if (shape_vec[i] <= 0) {
MS_EXCEPTION(ValueError) << "The element of shape must be positive, but got " << shape_value->ToString();
}
if ((i > 1) && (shape_vec[i] != values_shp[i - 1])) {
MS_EXCEPTION(ValueError)
<< "CSRTensor's shape[2: ] must be equal to value's shape[1: ], but CSRTensor's shape got: "
<< shape_value->ToString() << ", "
<< "values's shape got: " << values->shape()->ToString() << ".";
}
shape_size *= LongToSize(shape_vec[i]);
}
if (static_cast<int64_t>(shape_size) < values_shp[kIndexZero]) {
MS_EXCEPTION(ValueError) << "Shape total size: " << shape_size << " is too small to hold "
<< values_shp[kIndexZero] << " non-zero values.";
}
if (shape_vec[kIndexZero] + 1 != indptr_shp[kIndexZero]) {
MS_EXCEPTION(ValueError) << "Indptr must have length (1 + shape[0]), but got: " << indptr_shp[kIndexZero];
}
}
if (!IsDynamicRank(shape_vec) && values_shp.size() + 1 != shape_vec.size()) {
MS_EXCEPTION(ValueError) << "Values' dimension should equal to CSRTensor's dimension - 1, but got"
<< "Values' dimension: " << values_shp.size()
<< ", CSRTensor's dimension: " << shape_vec.size() << ".";
@ -101,27 +114,6 @@ AbstractBasePtr MakeCSRTensorInfer(const abstract::AnalysisEnginePtr &, const Pr
<< values_shp[kIndexZero] << ", indices length " << indices_shp[kIndexZero];
}
if (shape_vec[kIndexZero] + 1 != indptr_shp[kIndexZero]) {
MS_EXCEPTION(ValueError) << "Indptr must have length (1 + shape[0]), but got: " << indptr_shp[kIndexZero];
}
size_t shape_size = 1;
for (size_t i = 0; i < shape_vec.size(); ++i) {
if (shape_vec[i] <= 0) {
MS_EXCEPTION(ValueError) << "The element of shape must be positive, but got " << shape_value->ToString();
}
if ((i > 1) && (shape_vec[i] != values_shp[i - 1])) {
MS_EXCEPTION(ValueError)
<< "CSRTensor's shape[2: ] must be equal to value's shape[1: ], but CSRTensor's shape got: "
<< shape_value->ToString() << ", "
<< "values's shape got: " << values->shape()->ToString() << ".";
}
shape_size *= LongToSize(shape_vec[i]);
}
if (static_cast<int64_t>(shape_size) < values_shp[kIndexZero]) {
MS_EXCEPTION(ValueError) << "Shape total size: " << shape_size << " is too small to hold " << values_shp[kIndexZero]
<< " non-zero values.";
}
std::vector<abstract::AbstractBasePtr> element_list{indptr, indices, values, shape};
return std::make_shared<abstract::AbstractCSRTensor>(element_list);
}

View File

@ -95,6 +95,15 @@ void CheckSparseIndicesDtype(const TypePtr data_type, const std::string &arg_nam
void CheckSparseIndicesDtypeInt32(const TypePtr data_type, const std::string &arg_name);
inline void CheckInputShapeEmpty(const std::string &prim_name, const std::vector<AbstractBasePtr> &input_args) {
for (size_t i = 0; i < input_args.size(); ++i) {
MS_EXCEPTION_IF_NULL(input_args[i]->BuildShape());
if (input_args[i]->BuildShape()->IsDimZero()) {
MS_LOG(EXCEPTION) << "For '" << prim_name << "', input " << i << "'s shape should not be empty!";
}
}
}
ShapeVector ConvertToShapeVector(const abstract::AbstractTuplePtr &shape);
template <typename T>

View File

@ -63,6 +63,15 @@ def _make_tensor_with_dtype(data, dtype):
return Tensor(data, dtype=dtype)
@constexpr
def _convert_shape(shape):
"""Temporary solution to get shape value, will be removed when shape op is supported."""
if shape is None:
return (-2,)
shape = [-1 if i is None else i for i in shape]
return tuple(shape)
def is_scalar(tensor):
"""Determine whether tensor input is a scalar tensor."""
if tensor.size != 1:
@ -361,12 +370,12 @@ def csr_to_coo(tensor: CSRTensor) -> COOTensor:
"""
if not isinstance(tensor, CSRTensor):
raise_type_error("For functional operator csr_to_coo, input argument must be a CSRTensor.")
if len(tensor.shape) != 2:
if len(_convert_shape(tensor.shape)) > 2:
raise_value_error("Currently only support 2-D CSRTensor when converting to COOTensor.")
shape = tensor.shape
indices, values, _ = csr_sparse_matrix_to_sparse_tensor(Tensor(shape, mstype.int32), batch_csr_pointers_empty,
tensor.indptr, tensor.indices, tensor.values)
return COOTensor(indices, values, shape)
return COOTensor(indices, values, _convert_shape(shape))
def csr_to_dense(csr_tensor: CSRTensor) -> Tensor:
@ -401,12 +410,12 @@ def csr_to_dense(csr_tensor: CSRTensor) -> Tensor:
"""
if not isinstance(csr_tensor, CSRTensor):
raise_type_error("For functional operator csr_to_dense, input argument must be a CSRTensor.")
if len(csr_tensor.shape) != 2:
if len(csr_tensor.shape) > 2:
raise_value_error("Currently only support 2-D CSRTensor when converting to COOTensor.")
shape = csr_tensor.shape
shape = _convert_shape(csr_tensor.shape)
dense_shape = Tensor(shape, dtype=mstype.int32)
batch_pointers = Tensor([0, -1], dtype=mstype.int32)
batch_pointers = ops.concat((make_tensor([0]), ops.TensorShape()(csr_tensor.values))).astype("int32")
row_pointers = csr_tensor.indptr
col_indices = csr_tensor.indices
values = csr_tensor.values
@ -472,11 +481,11 @@ def dense_to_sparse_coo(tensor: Tensor) -> COOTensor:
"""
if not isinstance(tensor, Tensor):
raise_type_error("For functional operator dense_to_sparse_coo, input argument must be a Tensor.")
if len(tensor.shape) != 2:
if len(_convert_shape(tensor.shape)) > 2:
raise_value_error("Currently only support 2-D Tensor when converting to COOTensor.")
indices = tensor.nonzero().astype("int32")
values = gather_nd(tensor, indices)
return COOTensor(indices, values, tensor.shape)
return COOTensor(indices, values, _convert_shape(tensor.shape))
def dense_to_sparse_csr(tensor: Tensor) -> CSRTensor:
@ -518,11 +527,11 @@ def dense_to_sparse_csr(tensor: Tensor) -> CSRTensor:
"""
if not isinstance(tensor, Tensor):
raise_type_error("For functional operator dense_to_sparse_csr, input argument must be a Tensor.")
if len(tensor.shape) > 2:
if len(_convert_shape(tensor.shape)) > 2:
raise_value_error("Currently only support 2-D Tensor when converting to CSRTensor.")
indices = tensor.nonzero().astype("int32")
_, _, indptr, indices, values = dense_to_csr(tensor, indices)
return CSRTensor(indptr, indices, values, tensor.shape)
return CSRTensor(indptr, indices, values, _convert_shape(tensor.shape))
def make_sparse_tensor(indices, values, dense_shape):

View File

@ -0,0 +1,135 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""smoke tests for sparse dynamic shape operations"""
import pytest
import numpy as np
from mindspore import Tensor, nn, CSRTensor, ops
from mindspore.common import dtype as mstype
from tests.st.ops.dynamic_shape.grad.test_grad_of_dynamic import TestDynamicGrad
class NetDenseToCSR(nn.Cell):
def construct(self, x):
csr = x.to_csr()
return csr.indptr, csr.indices, csr.values
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dense_to_csr():
"""
Feature: Test tensor.to_csr in Graph and PyNative.
Description: Test tensor.to_csr in dynamic rank and dynamic shape.
Expectation: Success.
"""
test_dynamic = TestDynamicGrad(NetDenseToCSR())
x = Tensor(np.array([[2, 0, -1], [0, 0, 1]]), mstype.float32)
test_dynamic.test_dynamic_grad_net((x))
test_dynamic.test_dynamic_grad_net((x), is_dynamic_rank=True)
class NetDenseToCOO(nn.Cell):
def construct(self, x):
coo = x.to_coo()
return coo.indices, coo.values
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dense_to_coo():
"""
Feature: Test tensor.to_coo in Graph and PyNative.
Description: Test tensor.to_coo in dynamic rank and dynamic shape.
Expectation: Success.
"""
test_dynamic = TestDynamicGrad(NetDenseToCOO())
x = Tensor(np.array([[2, 0, -1], [0, 0, 1]]), mstype.float32)
test_dynamic.test_dynamic_grad_net((x))
test_dynamic.test_dynamic_grad_net((x), is_dynamic_rank=True)
class NetCSRToCOO(nn.Cell):
def construct(self, indptr, indices, values, shape):
csr = CSRTensor(indptr, indices, values, shape)
coo = ops.csr_to_coo(csr)
return coo.indices, coo.values
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csr_to_coo():
"""
Feature: Test ops.csr_to_coo in Graph and PyNative.
Description: Test ops.csr_to_coo in dynamic rank and dynamic shape.
Expectation: Success.
"""
test_dynamic = TestDynamicGrad(NetCSRToCOO())
x = Tensor(np.array([[2, 0, -1], [0, 0, 1]]), mstype.float32).to_csr()
args = (x.indptr, x.indices, x.values, x.shape)
test_dynamic.test_dynamic_grad_net(args)
test_dynamic.test_dynamic_grad_net(args, is_dynamic_rank=True)
class NetCSRToDense(nn.Cell):
def construct(self, indptr, indices, values, dense_shape):
x = CSRTensor(indptr, indices, values, dense_shape)
return x.to_dense()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csr_to_dense_dshape():
"""
Feature: Test csr_tensor.to_dense in Graph and PyNative.
Description: Test csr_tensor.to_dense in dynamic shape.
Expectation: Success.
"""
test_dynamic = TestDynamicGrad(NetCSRToDense())
indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
values = Tensor(np.arange(1, 7), dtype=mstype.float32)
dense_shape = (3, 4)
x = (indptr, indices, values, dense_shape)
test_dynamic.test_dynamic_grad_net(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_csr_to_dense_drank():
"""
Feature: Test csr_tensor.to_dense in Graph and PyNative.
Description: Test csr_tensor.to_dense in dynamic rank.
Expectation: Success.
"""
test_dynamic = TestDynamicGrad(NetCSRToDense())
indptr = Tensor([0, 1, 4, 6], dtype=mstype.int32)
indices = Tensor([3, 0, 1, 2, 1, 3], dtype=mstype.int32)
values = Tensor(np.arange(1, 7), dtype=mstype.float32)
dense_shape = (3, 4)
x = (indptr, indices, values, dense_shape)
test_dynamic.test_dynamic_grad_net(x, is_dynamic_rank=True)