merge all commits from sparse_tensor_dense_ds

hw_hz 2022-10-18 21:39:38 +08:00
parent a60615e409
commit 613861da9c
8 changed files with 271 additions and 128 deletions

View File

@@ -32,17 +32,41 @@ using complex64 = std::complex<float>;
using complex128 = std::complex<double>;
} // namespace
void SparseTensorDenseAddCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
auto indices_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex0);
auto values_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex1);
auto shape_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex2);
auto x2_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex3);
if (AnfAlgo::IsShapesDynamic({values_shape, indices_shape, shape_shape, x2_shape})) {
return;
bool SparseTensorDenseAddCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseTensorDenseAddInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseTensorDenseAddOutputsNum, kernel_name_);
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "SparseTensorDenseAdd does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
return true;
}
int SparseTensorDenseAddCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto indices_shape = inputs.at(kIndex0)->GetShapeVector();
auto values_shape = inputs.at(kIndex1)->GetShapeVector();
auto shape_shape = inputs.at(kIndex2)->GetShapeVector();
auto x2_shape = inputs.at(kIndex3)->GetShapeVector();
if (IsDynamic(values_shape) || IsDynamic(indices_shape) || IsDynamic(shape_shape) || IsDynamic(x2_shape)) {
return KRET_OK;
}
values_size_ = static_cast<size_t>(values_shape[0]);
output_shape_ = outputs.at(kIndex0)->GetShapeVector();
x2_shape_ = x2_shape;
size_t x1_rank = static_cast<size_t>(shape_shape[0]);
size_t x2_rank = x2_shape_.size();
if (indices_shape.size() != kIndicesShapeSize) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it requires 'x1_indices' must be a " << kIndicesShapeSize
<< "-D Tensor, but got " << indices_shape.size() << "-D";
@@ -53,30 +77,17 @@ void SparseTensorDenseAddCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
<< "must be equal to the first dimension length of 'indices', but got 'x1_values' shape: "
<< Vector2Str(values_shape) << " and 'x1_indices' shape: " << Vector2Str(indices_shape);
}
x2_shape_ = x2_shape;
size_t x1_rank = static_cast<size_t>(shape_shape[0]);
size_t x2_rank = x2_shape_.size();
if (x1_rank != x2_rank) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', x1 and x2 must have same ranks, but got 'x1' shape: " << Vector2Str(shape_shape)
<< "and 'x2' shape: " << Vector2Str(x2_shape_);
}
values_size_ = static_cast<size_t>(values_shape[0]);
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "SparseTensorDenseAdd does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
return KRET_OK;
}
template <typename I, typename T>
bool SparseTensorDenseAddCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseTensorDenseAddInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseTensorDenseAddOutputsNum, kernel_name_);
if (outputs[0]->size == 0) {
MS_LOG(WARNING) << "For '" << kernel_name_ << "', the output memory size is 0, so the launch is skipped.";
return true;
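The refactor above replaces the deprecated single-phase InitKernel with the Init/Resize pair used by dynamic-shape kernels: Init runs once and binds the typed compute function via MatchKernelAttr and func_list_, while Resize runs before every launch, re-reads the shapes, and returns KRET_OK early while any dimension is still unknown. A minimal Python model of that protocol (names are hypothetical, for illustration only, not the real C++ API):

KRET_OK = 0

class TwoPhaseKernelSketch:
    """Illustrative model of the Init/Resize split above."""

    def init(self, dtype, func_table):
        # Init: runs once; validate the data type and bind the typed
        # compute function, mirroring MatchKernelAttr + func_list_.
        if dtype not in func_table:
            raise TypeError(f"unsupported kernel data type: {dtype}")
        self.kernel_func = func_table[dtype]
        return True

    def resize(self, input_shapes):
        # Resize: runs before every launch; while any dimension is still
        # dynamic (-1), skip the static checks and report success.
        if any(-1 in shape for shape in input_shapes):
            return KRET_OK
        self.values_size = input_shapes[1][0]  # values_shape[0]
        return KRET_OK

mod = TwoPhaseKernelSketch()
mod.init("float32", {"float32": lambda *args: None})
print(mod.resize([[-1, 2], [-1]]))  # still dynamic -> early KRET_OK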

View File

@@ -21,17 +21,23 @@
#include <unordered_map>
#include <vector>
#include <utility>
#include <map>
#include <string>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class SparseTensorDenseAddCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class SparseTensorDenseAddCpuKernelMod : public NativeCpuKernelMod {
public:
SparseTensorDenseAddCpuKernelMod() = default;
~SparseTensorDenseAddCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
@@ -51,8 +57,8 @@ class SparseTensorDenseAddCpuKernelMod : public DeprecatedNativeCpuKernelMod {
std::function<bool(SparseTensorDenseAddCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, SparseTensorDenseAddFunc>> func_list_;
std::string kernel_name_;
SparseTensorDenseAddFunc kernel_func_;
ShapeVector x2_shape_;
ShapeVector output_shape_;
size_t values_size_{0};

View File

@@ -32,17 +32,37 @@ constexpr size_t kIndicesSizeNum = 2;
constexpr size_t kIndices2rdDimNum = 2;
} // namespace
void SparseTensorDenseMatmulCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
adj_st_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, ADJ_ST);
adj_dt_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, ADJ_dT);
auto indices_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, INDICES);
auto values_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, VALUES);
auto output_shape = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
auto b_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, static_cast<size_t>(DENSE));
if (AnfAlgo::IsShapesDynamic({values_shape, indices_shape, output_shape, b_shape})) {
return;
bool SparseTensorDenseMatmulCpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
auto prim = base_operator->GetPrim();
adj_st_ = GetValue<bool>(prim->GetAttr(ADJ_ST));
adj_dt_ = GetValue<bool>(prim->GetAttr(ADJ_dT));
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "SparseTensorDenseMatmul does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
return true;
}
int SparseTensorDenseMatmulCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto indices_shape = inputs.at(kIndex0)->GetShapeVector();
auto values_shape = inputs.at(kIndex1)->GetShapeVector();
auto b_shape = inputs.at(kIndex3)->GetShapeVector();
auto output_shape = outputs.at(kIndex0)->GetShapeVector();
std::vector<std::vector<int64_t>> all_shapes = {indices_shape, values_shape, b_shape, output_shape};
bool is_dynamic = std::any_of(all_shapes.begin(), all_shapes.end(), IsDynamic);
if (is_dynamic) {
return KRET_OK;
}
if (indices_shape.size() != kIndicesSizeNum || indices_shape[1] != kIndices2rdDimNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
@@ -67,13 +87,7 @@ void SparseTensorDenseMatmulCpuKernelMod::InitKernel(const CNodePtr &kernel_node
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output must be "
<< kSparseTensorDenseMatmulOutputShapeSize << "-D, but got " << output_shape_.size() << "-D";
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "SparseTensorDenseMatmul does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
return KRET_OK;
}
template <typename I, typename T>

View File

@@ -19,17 +19,23 @@
#include <vector>
#include <utility>
#include <map>
#include <string>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class SparseTensorDenseMatmulCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class SparseTensorDenseMatmulCpuKernelMod : public NativeCpuKernelMod {
public:
SparseTensorDenseMatmulCpuKernelMod() = default;
~SparseTensorDenseMatmulCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
@@ -46,7 +52,7 @@ class SparseTensorDenseMatmulCpuKernelMod : public DeprecatedNativeCpuKernelMod
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, SparseTensorDenseMatmulFunc>> func_list_;
SparseTensorDenseMatmulFunc kernel_func_;
std::string kernel_name_;
std::vector<size_t> output_shape_;
std::vector<size_t> b_shape_;
size_t output_size_{0};

View File

@@ -132,6 +132,7 @@ PrimShapeDependMap &GetHostDependsMap() {
static const auto &kTraceGrad = prim::kPrimTraceGrad->name();
static const auto &kSetSize = prim::kPrimSetSize->name();
static const auto &kDynamicStitch = prim::kPrimDynamicStitch->name();
static const auto &kSparseTensorDenseMatmul = prim::kPrimSparseTensorDenseMatmul->name();
// Common host depends.
static PrimShapeDependMap host_depends{{prim::kPrimArgMax->name(), ShapeSet{1}},
{prim::kPrimArgmin->name(), ShapeSet{1}},
@@ -204,7 +205,8 @@ PrimShapeDependMap &GetHostDependsMap() {
{prim::kPrimCumSum->name(), ShapeSet{1}},
{kAdaptiveMaxPool3DGrad, ShapeSet{1}},
{kSetSize, ShapeSet{2}},
{kDynamicStitch, ShapeSet{0}}};
{kDynamicStitch, ShapeSet{0}},
{kSparseTensorDenseMatmul, ShapeSet{2}}};
return host_depends;
}
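The new host-depends entry {kSparseTensorDenseMatmul, ShapeSet{2}} marks input index 2 (sparse_shape) as value-dependent: shape inference needs that tensor's contents on the host, because the output shape is computed from the value of sparse_shape rather than from its shape. A rough Python sketch of the computation this enables, mirroring the infer routine in the next file (illustrative only):

def infer_matmul_out_shape(sparse_shape, dense_shape,
                           adjoint_st=False, adjoint_dt=False):
    # sparse_shape here is the *value* of input 2, e.g. (3, 4); this is
    # why the op needs a host-depends entry for that input.
    x1_row, x1_col = sparse_shape
    x2_row, x2_col = dense_shape
    if adjoint_st:
        x1_row, x1_col = x1_col, x1_row
    if adjoint_dt:
        x2_row, x2_col = x2_col, x2_row
    return (x1_row, x2_col)

print(infer_matmul_out_shape((3, 4), (4, 2)))  # (3, 2)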

View File

@@ -28,6 +28,8 @@
namespace mindspore {
namespace ops {
namespace {
const int kDimensionOne = 1;
const int kDimensionTwo = 2;
void CheckShapeRank(const size_t cur_rank, const size_t expected_rank, const std::string &op_name,
const std::string &arg_name) {
if (cur_rank != expected_rank) {
@@ -47,6 +49,7 @@ bool checkType(std::string name, TypePtr dtype, std::set<TypePtr> vtypes, const
}
return true;
}
bool checkContainer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args, std::string *info) {
const int kTwo = 2;
const int kOne = 1;
@@ -70,92 +73,116 @@ bool checkContainer(const PrimitivePtr &primitive, const std::vector<AbstractBas
}
return true;
}
void SparseTensorDenseMatmulCheckShape(const std::string &prim_name, const bool &is_dynamic_rank,
const bool &is_dynamic, const ShapeVector &indices_shape,
const ShapeVector &values_shape, const ShapeVector &shape_shape,
const ShapeVector &x2_shape) {
if (!is_dynamic_rank) {
CheckShapeRank(indices_shape.size(), kDimensionTwo, prim_name, "indices");
CheckShapeRank(values_shape.size(), kDimensionOne, prim_name, "values");
CheckShapeRank(shape_shape.size(), kDimensionOne, prim_name, "sparse_shape");
CheckShapeRank(x2_shape.size(), kDimensionTwo, prim_name, "the shape of input dense");
}
if (!is_dynamic) {
if (indices_shape[1] != kDimensionTwo) {
MS_LOG(EXCEPTION) << "For '" << prim_name << "', the 2nd dimension of indices "
<< "should be 2, but got " << indices_shape[1] << ".";
}
if (values_shape[0] != indices_shape[0]) {
MS_LOG(EXCEPTION) << "For '" << prim_name << "', the input values' length "
<< "is different from indices' first dimension";
}
if (shape_shape[0] != kDimensionTwo) {
MS_LOG(EXCEPTION) << "For '" << prim_name << "', the 1st dimension of sparse_shape "
<< "should be 2, but got " << shape_shape[0] << ".";
}
}
}
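The refactor hoists the rank and dimension checks into the helper above. A compact Python rendering of the same validation rules, assuming all shapes are fully known (illustrative sketch, not the production code path):

def check_sparse_matmul_shapes(indices_shape, values_shape, shape_shape, x2_shape):
    # Rank checks first, then the static dimension constraints, in the
    # same order as SparseTensorDenseMatmulCheckShape.
    assert len(indices_shape) == 2, "indices must be 2-D"
    assert len(values_shape) == 1, "values must be 1-D"
    assert len(shape_shape) == 1, "sparse_shape must be 1-D"
    assert len(x2_shape) == 2, "dense input must be 2-D"
    assert indices_shape[1] == 2, "2nd dimension of indices must be 2"
    assert values_shape[0] == indices_shape[0], "values length must match indices"
    assert shape_shape[0] == 2, "sparse_shape must hold 2 elements"

check_sparse_matmul_shapes([2, 2], [2], [2], [4, 2])  # passes silently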
void SparseTensorDenseMatmulCheckShapeSetShape(const std::string &prim_name, int64_t *shape_ptr,
const ShapeVector &shape_shape, const AbstractBasePtr &x1_shape) {
if (x1_shape->isa<abstract::AbstractTensor>() && x1_shape->BuildValue()->isa<tensor::Tensor>()) {
auto a_shape = x1_shape->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(a_shape);
auto a_shape_value = a_shape->BuildValue();
MS_EXCEPTION_IF_NULL(a_shape_value);
auto a_shape_tensor = a_shape_value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(a_shape_tensor);
auto a_shape_size = a_shape_tensor->DataSize();
auto expect_size = std::accumulate(shape_shape.begin(), shape_shape.end(), 1, std::multiplies{});
MS_EXCEPTION_IF_CHECK_FAIL(a_shape_size == LongToSize(expect_size),
"For '" + prim_name + "', something unexpected happened.");
auto a_shape_ptr = a_shape_tensor->data_c();
for (size_t i = 0; i < kDimensionTwo; ++i) {
if (a_shape_tensor->Dtype() == kInt32) {
shape_ptr[i] = IntToLong(*(reinterpret_cast<int *>(a_shape_ptr) + i));
} else {
shape_ptr[i] = *(reinterpret_cast<int64_t *>(a_shape_ptr) + i);
}
}
} else if (IsIdentidityOrSubclass(x1_shape->BuildType(), kTuple)) {
auto value_tuple = GetValue<std::vector<int64_t>>(x1_shape->BuildValue());
for (size_t i = 0; i < kDimensionTwo; ++i) {
shape_ptr[i] = value_tuple[i];
}
}
}
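The helper above reads the constant sparse_shape tensor element by element, handling both int32 and int64 storage. A rough numpy equivalent of that extraction (illustrative only):

import numpy as np

def read_sparse_shape(tensor):
    # tensor is assumed to be a 1-D constant holding [rows, cols] as
    # int32 or int64; both widen safely to Python ints.
    data = np.asarray(tensor)
    assert data.size == 2
    return int(data[0]), int(data[1])

print(read_sparse_shape(np.array([3, 4], dtype=np.int32)))  # (3, 4)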
abstract::ShapePtr SparseTensorDenseMatmulInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto values_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
auto x2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->BuildShape())[kShape];
const int kDimensionTwo = 2;
const int kDimensionOne = 1;
auto x1_shape = input_args[2];
auto x1_shape_value = x1_shape->BuildValue();
std::string info;
if (!checkContainer(primitive, input_args, &info)) {
MS_EXCEPTION(TypeError) << "For " << primitive->name() << info;
MS_EXCEPTION(TypeError) << "For " << prim_name << info;
}
if (x1_shape->isa<abstract::AbstractTuple>()) {
int64_t shape_len = static_cast<int64_t>(GetValue<std::vector<int64_t>>(x1_shape_value).size());
shape_shape = std::vector<int64_t>{shape_len};
}
std::vector<std::vector<int64_t>> all_shapes = {indices_shape, values_shape, shape_shape, x2_shape};
bool is_dynamic = std::any_of(all_shapes.begin(), all_shapes.end(), IsDynamic);
bool is_dynamic_rank = std::any_of(all_shapes.begin(), all_shapes.end(), IsDynamicRank);
if (!is_dynamic_rank) {
CheckShapeRank(indices_shape.size(), kDimensionTwo, primitive->name(), "indices");
CheckShapeRank(values_shape.size(), kDimensionOne, primitive->name(), "values");
CheckShapeRank(shape_shape.size(), kDimensionOne, primitive->name(), "sparse_shape");
CheckShapeRank(x2_shape.size(), kDimensionTwo, primitive->name(), "the shape of input dense");
}
SparseTensorDenseMatmulCheckShape(prim_name, is_dynamic_rank, is_dynamic, indices_shape, values_shape, shape_shape,
x2_shape);
if (!is_dynamic) {
if (indices_shape[1] != kDimensionTwo) {
MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', the 2nd dimension of indices "
<< "should be 2, but got " << indices_shape[1] << ".";
}
if (values_shape[0] != indices_shape[0]) {
MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', the input values' length "
<< "is different from indices' first dimension";
}
if (shape_shape[0] != kDimensionTwo) {
MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', the 1st dimension of sparse_shape "
<< "should be 2, but got " << shape_shape[0] << ".";
}
if (x1_shape_value->isa<AnyValue>() || x1_shape_value->isa<None>()) {
MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', the input sparse_shape "
<< "should be constant.";
}
}
auto adjoint_a = primitive->GetAttr("adjoint_st");
auto adjoint_b = primitive->GetAttr("adjoint_dt");
bool adjoint_av = GetValue<bool>(adjoint_a);
bool adjoint_bv = GetValue<bool>(adjoint_b);
int64_t x1_row = -1, x1_col = -1;
int64_t x2_row = -1, x2_col = -1;
if (is_dynamic_rank) {
return std::make_shared<abstract::Shape>(std::vector<int64_t>{x1_row, x2_col});
}
ShapeVector shape{-1, -1};
SparseTensorDenseMatmulCheckShapeSetShape(prim_name, shape.data(), shape_shape, x1_shape);
if (shape.size() == kDimensionTwo) {
x1_row = shape[0];
x1_col = shape[1];
}
if (x2_shape.size() == kDimensionTwo) {
x2_row = x2_shape[0];
x2_col = x2_shape[1];
}
if (x1_shape->isa<abstract::AbstractTuple>()) {
auto temp = GetValue<std::vector<int64_t>>(x1_shape_value);
if (temp.size() == kDimensionTwo) {
x1_row = temp[0];
x1_col = temp[1];
}
if (adjoint_av) std::swap(x1_row, x1_col);
if (adjoint_bv) std::swap(x2_row, x2_col);
int64_t y_row = x1_row, y_col = x2_col;
std::vector<int64_t> y_shape{y_row, y_col};
return std::make_shared<abstract::Shape>(y_shape);
}
if (x1_shape_value->isa<tensor::Tensor>()) {
auto shape = CheckAndConvertUtils::CheckTensorIntValue("x1_shape", x1_shape_value, primitive->name());
if (shape.size() == kDimensionTwo) {
x1_row = shape[0];
x1_col = shape[1];
}
}
if (adjoint_av) std::swap(x1_row, x1_col);
if (adjoint_bv) std::swap(x2_row, x2_col);
int64_t y_row = x1_row, y_col = x2_col;
std::vector<int64_t> y_shape{y_row, y_col};
return std::make_shared<abstract::Shape>(y_shape);
}
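As the two branches above show, sparse_shape may arrive either as a constant Tensor or as a Python tuple. A hedged usage sketch in the style of the tests later in this commit (assumes a MindSpore build where nn.SparseTensorDenseMatmul is available, as the CPU test below uses):

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor

matmul = nn.SparseTensorDenseMatmul()
indices = Tensor([[0, 1], [1, 2]], dtype=ms.int64)
values = Tensor([1.0, 2.0], dtype=ms.float32)
dense = Tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype=ms.float32)
# sparse_shape passed as a tuple; a constant int64 Tensor([3, 4]) follows
# the tensor branch of the infer routine instead.
out = matmul(indices, values, (3, 4), dense)
print(out.shape)  # (3, 2)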
TypePtr SparseTensorDenseMatmulInferType(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr arg) { return arg == nullptr; })) {

View File

@@ -18,6 +18,7 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import composite as C
@@ -25,6 +26,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class SparseDenseMatmulNet(nn.Cell):
def __init__(self, adjoint_st=False, adjoint_dt=False):
super(SparseDenseMatmulNet, self).__init__()
self.matmul = nn.SparseTensorDenseMatmul(adjoint_st, adjoint_dt)
@@ -34,6 +36,7 @@ class SparseDenseMatmulNet(nn.Cell):
class GradNet(nn.Cell):
def __init__(self, network):
super(GradNet, self).__init__()
self.grad = C.GradOperation(get_all=True, sens_param=False)
@@ -49,6 +52,33 @@ def judge_result_correct(result, expect):
assert np.allclose(result, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sparse_tensor_dense_mul_dyn():
"""
Feature: test SparseTensorDenseMatmul op on CPU.
Description: test the op with dynamic shape inputs.
Expectation: expect the correct output shape.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = SparseDenseMatmulNet()
x1_indices_dyn = Tensor(shape=[None, 2], dtype=ms.int64)
x1_values_dyn = Tensor(shape=[None], dtype=ms.float32)
x1_shape = Tensor([3, 4], dtype=ms.int64)
x2_dyn = Tensor(shape=[4, None], dtype=ms.float32)
net.set_inputs(x1_indices_dyn, x1_values_dyn, x1_shape, x2_dyn)
x1_indices = Tensor([[0, 1], [1, 2]], dtype=ms.int64)
x1_values = Tensor([1, 2], dtype=ms.float32)
x2 = Tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype=ms.float32)
out = net(x1_indices, x1_values, x1_shape, x2)
expect_out_shape = (3, 2)
assert out.asnumpy().shape == expect_out_shape
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -56,8 +86,10 @@ def test_sparse_tensor_dense_matmul_no_transpose():
indices_np = np.array([[0, 0], [1, 1], [2, 2], [2, 3]], np.int64)
values_np = np.array([2, 3, 4, 5], np.float16)
dense_shape = (3, 4)
sparse_np = np.array([[2, 0, 0, 0], [0, 3, 0, 0], [0, 0, 4, 5]], dtype=np.float16)
dense_np = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=np.float16)
sparse_np = np.array([[2, 0, 0, 0], [0, 3, 0, 0], [0, 0, 4, 5]],
dtype=np.float16)
dense_np = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]],
dtype=np.float16)
sparse_dense_matmul_net = SparseDenseMatmulNet()
indices = Tensor(indices_np)
@@ -71,10 +103,9 @@ def test_sparse_tensor_dense_matmul_no_transpose():
grad_ms = grad_net(indices, values, dense_shape, dense)
expect_values_grad = np.array([3., 12., 21., 30.], dtype=np.float16)
judge_result_correct(grad_ms[1].asnumpy(), expect_values_grad)
expect_dense_grad = np.array([[2., 2., 2.],
[3., 3., 3.],
[4., 4., 4.],
[5., 5., 5.]], dtype=np.float16)
expect_dense_grad = np.array(
[[2., 2., 2.], [3., 3., 3.], [4., 4., 4.], [5., 5., 5.]],
dtype=np.float16)
judge_result_correct(grad_ms[2].asnumpy(), expect_dense_grad)
@@ -82,11 +113,14 @@ def test_sparse_tensor_dense_matmul_no_transpose():
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sparse_tensor_dense_matmul_transpose_a():
indices_np = np.array([[0, 0], [1, 1], [2, 0], [2, 2], [3, 1], [3, 2]], np.int32)
indices_np = np.array([[0, 0], [1, 1], [2, 0], [2, 2], [3, 1], [3, 2]],
np.int32)
values_np = np.array([1, 2, 3, 4, 5, 6], np.float64)
dense_shape = (4, 3)
sparse_np = np.array([[1, 0, 0], [0, 2, 0], [3, 0, 4], [0, 5, 6]], dtype=np.float64)
dense_np = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=np.float64)
sparse_np = np.array([[1, 0, 0], [0, 2, 0], [3, 0, 4], [0, 5, 6]],
dtype=np.float64)
dense_np = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]],
dtype=np.float64)
sparse_dense_matmul_net = SparseDenseMatmulNet(adjoint_st=True)
indices = Tensor(indices_np)
@@ -99,12 +133,12 @@ def test_sparse_tensor_dense_matmul_transpose_a():
grad_net = GradNet(sparse_dense_matmul_net)
grad_ms = grad_net(indices, values, dense_shape, dense)
expect_values_grad = np.array([3., 12., 21., 21., 30., 30.], dtype=np.float64)
expect_values_grad = np.array([3., 12., 21., 21., 30., 30.],
dtype=np.float64)
judge_result_correct(grad_ms[1].asnumpy(), expect_values_grad)
expect_dense_grad = np.array([[1., 1., 1.],
[2., 2., 2.],
[7., 7., 7.],
[11., 11., 11.]], dtype=np.float64)
expect_dense_grad = np.array(
[[1., 1., 1.], [2., 2., 2.], [7., 7., 7.], [11., 11., 11.]],
dtype=np.float64)
judge_result_correct(grad_ms[2].asnumpy(), expect_dense_grad)
@@ -112,11 +146,14 @@ def test_sparse_tensor_dense_matmul_transpose_a():
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sparse_tensor_dense_matmul_transpose_b():
indices_np = np.array([[0, 0], [1, 1], [2, 0], [2, 2], [3, 1], [3, 2]], np.int64)
indices_np = np.array([[0, 0], [1, 1], [2, 0], [2, 2], [3, 1], [3, 2]],
np.int64)
values_np = np.array([1, 2, 3, 4, 5, 6], np.int32)
dense_shape = (4, 3)
sparse_np = np.array([[1, 0, 0], [0, 2, 0], [3, 0, 4], [0, 5, 6]], dtype=np.int32)
dense_np = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=np.int32)
sparse_np = np.array([[1, 0, 0], [0, 2, 0], [3, 0, 4], [0, 5, 6]],
dtype=np.int32)
dense_np = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]],
dtype=np.int32)
sparse_dense_matmul_net = SparseDenseMatmulNet(adjoint_dt=True)
indices = Tensor(indices_np)
@@ -129,12 +166,11 @@ def test_sparse_tensor_dense_matmul_transpose_b():
grad_net = GradNet(sparse_dense_matmul_net)
grad_ms = grad_net(indices, values, dense_shape, dense)
expect_values_grad = np.array([18., 22., 18., 26., 22., 26.], dtype=np.int32)
expect_values_grad = np.array([18., 22., 18., 26., 22., 26.],
dtype=np.int32)
judge_result_correct(grad_ms[1].asnumpy(), expect_values_grad)
expect_dense_grad = np.array([[4, 7, 10],
[4, 7, 10],
[4, 7, 10],
[4, 7, 10]], dtype=np.int32)
expect_dense_grad = np.array(
[[4, 7, 10], [4, 7, 10], [4, 7, 10], [4, 7, 10]], dtype=np.int32)
judge_result_correct(grad_ms[2].asnumpy(), expect_dense_grad)
@@ -145,24 +181,26 @@ def test_sparse_tensor_dense_matmul_transpose_all():
indices_np = np.array([[0, 0], [1, 1], [2, 2], [2, 3]], np.int64)
values_np = np.array([2, 3, 4, 5], np.int64)
dense_shape = (3, 4)
sparse_np = np.array([[2, 0, 0, 0], [0, 3, 0, 0], [0, 0, 4, 5]], dtype=np.int64)
dense_np = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=np.int64)
sparse_np = np.array([[2, 0, 0, 0], [0, 3, 0, 0], [0, 0, 4, 5]],
dtype=np.int64)
dense_np = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]],
dtype=np.int64)
sparse_dense_matmul_net = SparseDenseMatmulNet(adjoint_st=True, adjoint_dt=True)
sparse_dense_matmul_net = SparseDenseMatmulNet(adjoint_st=True,
adjoint_dt=True)
indices = Tensor(indices_np)
values = Tensor(values_np)
dense = Tensor(dense_np)
out_ms = sparse_dense_matmul_net(indices, values, dense_shape, dense)
perm = (1, 0)
out_np = np.matmul(np.transpose(sparse_np, perm), np.transpose(dense_np, perm))
out_np = np.matmul(np.transpose(sparse_np, perm),
np.transpose(dense_np, perm))
judge_result_correct(out_ms.asnumpy(), out_np)
grad_net = GradNet(sparse_dense_matmul_net)
grad_ms = grad_net(indices, values, dense_shape, dense)
expect_values_grad = np.array([18, 22, 26, 26], dtype=np.int64)
judge_result_correct(grad_ms[1].asnumpy(), expect_values_grad)
expect_dense_grad = np.array([[2, 3, 9],
[2, 3, 9],
[2, 3, 9],
[2, 3, 9]], dtype=np.int64)
expect_dense_grad = np.array([[2, 3, 9], [2, 3, 9], [2, 3, 9], [2, 3, 9]],
dtype=np.int64)
judge_result_correct(grad_ms[2].asnumpy(), expect_dense_grad)

View File

@@ -16,15 +16,27 @@
import numpy as np
import pytest
from mindspore import Tensor, context
import mindspore as ms
from mindspore import Tensor, context, nn
from mindspore.ops.operations.sparse_ops import SparseTensorDenseAdd
class Net(nn.Cell):
def __init__(self) -> None:
super(Net, self).__init__()
self.op = SparseTensorDenseAdd()
def construct(self, x1_indices, x1_values, x1_shape, x2):
return self.op(x1_indices, x1_values, x1_shape, x2)
def generate_data(datatype="float32", indicetype="int32"):
x1_indices = Tensor(np.array([[0, 1], [1, 2]]).astype(indicetype))
x1_values = Tensor(np.array([1, 2]).astype(datatype))
x1_shape = Tensor(np.array([3, 4]).astype(indicetype))
x2 = Tensor(np.array([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]).astype(datatype))
x2 = Tensor(
np.array([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]).astype(datatype))
data = x1_indices, x1_values, x1_shape, x2
return data
@@ -60,6 +72,33 @@ def test_sparse_tensor_dense_add(indicetype, datatype):
data = generate_data(datatype=datatype, indicetype=indicetype)
net = SparseTensorDenseAdd()
out = net(data[0], data[1], data[2], data[3]).asnumpy()
expected = np.array([[1, 2, 1, 1], [1, 1, 3, 1], [1, 1, 1, 1]]).astype(datatype)
eps = 1e-6*np.array(np.ones_like(out))
expected = np.array([[1, 2, 1, 1], [1, 1, 3, 1],
[1, 1, 1, 1]]).astype(datatype)
eps = 1e-6 * np.array(np.ones_like(out))
assert np.all(expected - out < eps)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu
@pytest.mark.env_onecard
def test_sparse_tensor_dense_add_dyn():
"""
Feature: test SparseTensorDenseAdd op on GPU.
Description: test the op with dynamic shape inputs.
Expectation: expect the correct output shape.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = Net()
x1_indices_dyn = Tensor(shape=[None, 2], dtype=ms.int64)
x1_values_dyn = Tensor(shape=[None], dtype=ms.float32)
x1_shape = Tensor([3, 3], dtype=ms.int64)
x2 = Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]], dtype=ms.float32)
net.set_inputs(x1_indices_dyn, x1_values_dyn, x1_shape, x2)
x1_indices = Tensor([[0, 0], [0, 1]], dtype=ms.int64)
x1_values = Tensor([1, 1], dtype=ms.float32)
out = net(x1_indices, x1_values, x1_shape, x2)
expect_out_shape = (3, 3)
assert out.asnumpy().shape == expect_out_shape