!16666 implement bprop of SparseTensorDenseMatmul

From: @zhangbuxue
Reviewed-by: @wuxuejian, @liangchenghui
Signed-off-by: @wuxuejian
mindspore-ci-bot 2021-05-21 10:26:30 +08:00 committed by Gitee
commit fe79dae4a5
9 changed files with 262 additions and 78 deletions

View File

@@ -59,6 +59,8 @@ const char START[] = "start";
const char LIMIT[] = "limit";
const char DELTA[] = "delta";
const char SORTED[] = "sorted";
const char ADJ_ST[] = "adjoint_st";
const char ADJ_dT[] = "adjoint_dt";
enum OperateType {
ADD = 0,

View File

@@ -20,7 +20,8 @@
namespace mindspore {
namespace kernel {
void GatherV2CPUKernel::InitKernel(const CNodePtr &kernel_node) {
template <typename T>
void GatherV2CPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
@@ -34,7 +35,8 @@ void GatherV2CPUKernel::InitKernel(const CNodePtr &kernel_node) {
CPUKernelUtils::ExpandDimsTo4(&output_shape_);
}
int GatherV2CPUKernel::GatherLaunch(int8_t *input_data, int8_t *output_data, size_t size) {
template <typename T>
int GatherV2CPUKernel<T>::GatherLaunch(int8_t *input_data, int8_t *output_data, size_t size) {
int in_rank = input_shape_.size();
int indices_element_size = 1;
const int limit = input_shape_.at(axis_);
@@ -64,7 +66,7 @@ int GatherV2CPUKernel::GatherLaunch(int8_t *input_data, int8_t *output_data, siz
int8_in += thread_stride * limit * inner_size * data_size;
int8_out += thread_stride * indices_element_size * inner_size * data_size;
auto error_code =
Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, sizeof(float));
Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, sizeof(T));
if (error_code != 0) {
MS_LOG(ERROR) << "GatherRun error task_id[" << i << "] error_code[" << error_code << "]";
}
@@ -75,9 +77,10 @@ int GatherV2CPUKernel::GatherLaunch(int8_t *input_data, int8_t *output_data, siz
return 0;
}
bool GatherV2CPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
template <typename T>
bool GatherV2CPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
int8_t *input_tensor = reinterpret_cast<int8_t *>(inputs[0]->addr);
indices_data_ = reinterpret_cast<int32_t *>(inputs[1]->addr);
int8_t *output_addr = reinterpret_cast<int8_t *>(outputs[0]->addr);
@@ -87,7 +90,8 @@ bool GatherV2CPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
return true;
}
void GatherV2CPUKernel::CheckParam(const CNodePtr &kernel_node) {
template <typename T>
void GatherV2CPUKernel<T>::CheckParam(const CNodePtr &kernel_node) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape.size() > 4) {
MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but GatherV2CPUKernel olny support 4d or lower.";
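The sizeof(float) -> sizeof(T) change above is the key fix in this file: GatherLaunch moves raw bytes through int8_t pointers, so the per-element byte count must match the instantiated type T rather than being hard-coded to 4 bytes. A rough NumPy analogue of the gather being performed, with the element size carried by the array's dtype (illustrative only, not part of the commit):

import numpy as np

x = np.arange(24, dtype=np.float64).reshape(2, 3, 4)  # itemsize is 8 here, not sizeof(float) == 4
idx = np.array([2, 0], dtype=np.int32)
out = np.take(x, idx, axis=1)                         # gather along an axis -> shape (2, 2, 4)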

View File

@@ -23,9 +23,10 @@
namespace mindspore {
namespace kernel {
template <typename T>
class GatherV2CPUKernel : public CPUKernel {
public:
GatherV2CPUKernel() : axis_(0) {}
GatherV2CPUKernel() = default;
~GatherV2CPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
@@ -40,13 +41,44 @@ class GatherV2CPUKernel : public CPUKernel {
std::vector<size_t> indices_shape_;
std::vector<size_t> output_shape_;
int *indices_data_ = nullptr;
int64_t axis_;
int64_t axis_{0};
};
MS_REG_CPU_KERNEL(
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
GatherV2CPUKernel, bool);
MS_REG_CPU_KERNEL_T(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
GatherV2CPUKernel);
GatherV2CPUKernel, float);
MS_REG_CPU_KERNEL_T(
Gather,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
GatherV2CPUKernel, double);
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
GatherV2CPUKernel, int8_t);
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
GatherV2CPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
GatherV2CPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
GatherV2CPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
GatherV2CPUKernel, uint8_t);
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
GatherV2CPUKernel, uint16_t);
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
GatherV2CPUKernel, uint32_t);
MS_REG_CPU_KERNEL_T(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
GatherV2CPUKernel, uint64_t);
} // namespace kernel
} // namespace mindspore
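With the kernel now templated, CPU Gather accepts every dtype registered above. A minimal usage sketch (assuming a CPU graph context, matching the test setup later in this diff; the values are illustrative):

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
params = Tensor(np.array([[1, 2], [3, 4], [5, 6]], dtype=np.uint16))  # one of the newly supported dtypes
indices = Tensor(np.array([2, 0], dtype=np.int32))
out = P.Gather()(params, indices, 0)  # gathers rows 2 and 0 -> [[5, 6], [1, 2]]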

View File

@@ -14,26 +14,25 @@
* limitations under the License.
*/
#include <functional>
#include <numeric>  // std::accumulate
#include "backend/kernel_compiler/cpu/sparse_tensor_dense_matmul_cpu_kernel.h"
namespace mindspore {
namespace kernel {
template <typename I, typename T>
void SparseTensorDenseMatmulCPUKernel<I, T>::InitKernel(const CNodePtr &kernel_node) {
output_size_ = 1;
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
for (auto &dim : output_shape) {
output_size_ *= dim;
adj_st_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, ADJ_ST);
adj_dt_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, ADJ_dT);
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
output_size_ = std::accumulate(output_shape_.begin(), output_shape_.end(), size_t(1), std::multiplies<size_t>());
auto values_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
if (values_shape.size() != 1) {
MS_LOG(EXCEPTION) << "SparseTensorDenseMatmul requires the values must be a 1-D tensor, but got "
<< values_shape.size() << "-D";
}
aValues_size_ = 1;
auto aValues_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
for (auto &dim : aValues_shape) {
aValues_size_ *= dim;
}
b_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 3);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
values_size_ = values_shape[0];
b_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
}
template <typename I, typename T>
@@ -44,29 +43,31 @@ bool SparseTensorDenseMatmulCPUKernel<I, T>::Launch(const std::vector<kernel::Ad
auto a_values = reinterpret_cast<T *>(inputs[1]->addr);
auto b = reinterpret_cast<T *>(inputs[3]->addr);
auto out = reinterpret_cast<T *>(outputs[0]->addr);
memset(out, 0, output_size_);
const size_t nnz = aValues_size_;
const size_t rhs_right = b_shape_[1];
const size_t lhs_right = b_shape_[0];
const size_t out_dim_0 = output_shape_[0];
const size_t out_dim_1 = output_shape_[1];
const size_t b_dim_0 = b_shape_[0];
const size_t b_dim_1 = b_shape_[1];
const size_t same_dim = adj_dt_ ? b_dim_1 : b_dim_0;
for (size_t i = 0; i < nnz; ++i) {
const size_t m = a_indices[i * 2];
const size_t k = a_indices[i * 2 + 1];
if (k > lhs_right) {
MS_LOG(ERROR) << "Invalid value: k: " << k << ", lhs_right: " << lhs_right;
return false;
}
if (m > output_shape_[0]) {
MS_LOG(ERROR) << "Invalid value: m: " << m << ", output_shape: " << output_shape_[0];
for (size_t i = 0; i < values_size_; ++i) {
const int row = adj_st_ ? a_indices[i * 2 + 1] : a_indices[i * 2];
const int col = adj_st_ ? a_indices[i * 2] : a_indices[i * 2 + 1];
if (row >= SizeToInt(out_dim_0) || row < 0 || col >= SizeToInt(same_dim) || col < 0) {
MS_LOG(ERROR) << "The indices include an out-of-bounds index, row range: [0, " << out_dim_0 << "), col range: [0, "
<< same_dim << "), but got row: " << row << ", col: " << col;
return false;
}
for (size_t n = 0; n < rhs_right; ++n) {
const float b_value = b[k * lhs_right + n];
out[m * output_shape_[0] + n] += a_values[i] * b_value;
for (size_t n = 0; n < out_dim_1; ++n) {
if (adj_dt_) {
const T b_value = b[n * b_dim_1 + col];
out[row * out_dim_1 + n] += a_values[i] * b_value;
} else {
const T b_value = b[col * b_dim_1 + n];
out[row * out_dim_1 + n] += a_values[i] * b_value;
}
}
}
return true;
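For reference, the launch loop above computes out = op(A) @ op(B) with A given in COO form, where op transposes its argument when the corresponding adjoint flag is set; transposing the sparse operand amounts to swapping each index pair. A NumPy-style sketch of the same logic (an illustration, not the kernel itself):

import numpy as np

def sparse_dense_matmul_ref(indices, values, dense_shape, b, adj_st=False, adj_dt=False):
    # Effective row count of op(A); a COO transpose just swaps the index columns.
    out_dim_0 = dense_shape[1] if adj_st else dense_shape[0]
    out_dim_1 = b.shape[0] if adj_dt else b.shape[1]
    out = np.zeros((out_dim_0, out_dim_1), dtype=b.dtype)
    for (i, j), v in zip(indices, values):
        row, col = (j, i) if adj_st else (i, j)
        out[row, :] += v * (b[:, col] if adj_dt else b[col, :])  # scaled row of op(B)
    return out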

View File

@@ -37,8 +37,10 @@ class SparseTensorDenseMatmulCPUKernel : public CPUKernel {
private:
std::vector<size_t> output_shape_;
std::vector<size_t> b_shape_;
size_t output_size_;
size_t aValues_size_;
size_t output_size_{0};
size_t values_size_{0};
bool adj_st_{false};
bool adj_dt_{false};
};
MS_REG_CPU_KERNEL_T_S(SparseTensorDenseMatmul,
KernelAttr()

View File

@@ -36,7 +36,6 @@ class SparseToDense(Cell):
>>> import mindspore as ms
>>> from mindspore import Tensor, SparseTensor
>>> import mindspore.nn as nn
>>> indices = Tensor([[0, 1], [1, 2]])
>>> values = Tensor([1, 2], dtype=ms.int32)
>>> dense_shape = (3, 4)
@@ -48,6 +47,7 @@
[0 0 2 0]
[0 0 0 0]]
"""
def __init__(self):
super(SparseToDense, self).__init__()
self.sparse_to_dense = P.SparseToDense()
@@ -57,6 +57,7 @@
sparse_tensor.values,
sparse_tensor.dense_shape)
class SparseTensorDenseMatmul(Cell):
"""
Multiplies SparseTensor (of rank 2) "A" by a dense tensor.
@@ -94,6 +95,7 @@
>>> test_SparseDenseMatmul = NetSparseDenseMatmul()
>>> out = test_SparseDenseMatmul(indices, values, dens_shape, dsMatrix)
"""
def __init__(self, adjoint_st=False, adjoint_dt=False):
"""Initialize SparseTensorDenseMatmul"""
super(SparseTensorDenseMatmul, self).__init__()

View File

@@ -19,6 +19,7 @@ from .. import operations as P
from .. import functional as F
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from .grad_base import bprops, bprop_getters
# Unused parameters are placeholders.
@@ -56,3 +57,24 @@ def get_bprop_sparse_to_dense(self):
return zeros_like(indices), dout, zeros_like(dense_shape)
return bprop
@bprop_getters.register(P.SparseTensorDenseMatmul)
def get_bprop_sparse_tensor_dense_matmul(self):
"""Generate bprop for SparseTensorDenseMatmul"""
adj_s = self.adjoint_st
adj_d = self.adjoint_dt
sparse_tensor_dense_mat_mul = P.SparseTensorDenseMatmul(not adj_s)
def bprop(indices, values, dense_shape, dense, out, dout):
dense_grad = sparse_tensor_dense_mat_mul(indices, values, dense_shape, dout)
perm = (1, 0)
if adj_d:
dense_grad = F.transpose(dense_grad, perm)
rows = indices[:, 0]
cols = indices[:, 1]
parts_a = F.gather(dout, cols if adj_s else rows, 0)
parts_b = F.gather(F.transpose(dense, perm) if adj_d else dense, rows if adj_s else cols, 0)
values_grad = F.reduce_sum(parts_a * parts_b, 1)
return zeros_like(indices), values_grad, zeros_like(dense_shape), dense_grad
return bprop
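The math behind this bprop, in the non-adjoint case: for out = A @ B with A sparse, the dense gradient is d(B) = A^T @ dout, which the code obtains by reusing the primitive with adjoint_st flipped, and the gradient of each nonzero value at position (r, c) is the dot product of dout[r, :] with B[c, :], which the gathers and reduce_sum compute batched over all nonzeros. A NumPy check of both formulas on the data of the first test case below (illustrative; assumes adjoint_st = adjoint_dt = False):

import numpy as np

indices = np.array([[0, 0], [1, 1], [2, 2], [2, 3]])
values = np.array([2., 3., 4., 5.])
dense = np.arange(12, dtype=np.float64).reshape(4, 3)
dout = np.ones((3, 3))                             # upstream gradient, output shape is (3, 3)

dense_grad = np.zeros_like(dense)                  # A^T @ dout, accumulated over the nonzeros
for (r, c), v in zip(indices, values):
    dense_grad[c, :] += v * dout[r, :]

values_grad = np.array([dout[r] @ dense[c] for r, c in indices])
# values_grad -> [3., 12., 21., 30.] and dense_grad rows -> [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5],
# matching expect_values_grad and expect_dense_grad in test_sparse_tensor_dense_matmul_no_transpose.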

View File

@@ -22,6 +22,7 @@ from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register
class SparseToDense(PrimitiveWithInfer):
"""
Converts a sparse representation into a dense tensor.
@@ -57,6 +58,7 @@ class SparseToDense(PrimitiveWithInfer):
'value': None}
return out
class SparseTensorDenseMatmul(PrimitiveWithInfer):
"""
Multiplies SparseTensor (of rank 2) "A" by a dense tensor.
@@ -95,6 +97,7 @@ class SparseTensorDenseMatmul(PrimitiveWithInfer):
>>> dsMatrix = Tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype=ms.float32)
>>> out = ops.SparseTensorDenseMatmul()(indices, values, dense_shape, dsMatrix)
"""
@prim_attr_register
def __init__(self, adjoint_st=False, adjoint_dt=False):
"""Initialize SparseTensorDenseMatmul"""
@@ -112,15 +115,16 @@ class SparseTensorDenseMatmul(PrimitiveWithInfer):
valid_types = mstype.number_type + (mstype.bool_,)
args = {'values': values['dtype'], 'dense': dense['dtype']}
validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
a_shape = dense_shape['value']
b_shape = dense['shape']
a_shape = dense_shape['value'][::-1] if self.adjoint_st else dense_shape['value']
b_shape = dense['shape'][::-1] if self.adjoint_dt else dense['shape']
if len(a_shape) != 2 or len(b_shape) != 2:
raise ValueError('SparseTensorDenseMatmul SparseTensor, DenseTensor should have the same dimension size '
+ f'and equal to 2, while SparseTensor size is ({len(a_shape)}) and DenseTensor size is '
raise ValueError('SparseTensorDenseMatmul requires SparseTensor and DenseTensor to both be 2-dimensional, '
+ f'while SparseTensor dim is ({len(a_shape)}) and DenseTensor dim is '
+ f'({len(b_shape)}).')
out_shape = []
out_shape.append(a_shape[0])
out_shape.append(b_shape[1])
if a_shape[1] != b_shape[0]:
raise ValueError('SparseTensorDenseMatmul requires SparseTensor dim_1 to be equal to DenseTensor dim_0, '
f'but got SparseTensor dim_1: {a_shape[1]}, DenseTensor dim_0: {b_shape[0]}')
out_shape = [a_shape[0], b_shape[1]]
out = {'shape': tuple(out_shape),
'dtype': values['dtype'],
'value': None}
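Stated compactly: reverse each declared shape when its adjoint flag is set, require both effective shapes to be rank-2 with a matching inner dimension, and emit (a_shape[0], b_shape[1]). A sketch of that rule as a standalone helper (illustrative, not the primitive's code):

def infer_out_shape(sparse_shape, dense_shape, adjoint_st=False, adjoint_dt=False):
    a = sparse_shape[::-1] if adjoint_st else sparse_shape
    b = dense_shape[::-1] if adjoint_dt else dense_shape
    assert len(a) == 2 and len(b) == 2 and a[1] == b[0]
    return (a[0], b[1])

infer_out_shape((3, 4), (4, 3), adjoint_st=True, adjoint_dt=True)  # -> (4, 4)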

View File

@@ -14,40 +14,155 @@
# ============================================================================
import numpy as np
import mindspore as ms
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import SparseTensor
from mindspore.ops import composite as C
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class NetSparseDenseMatmul(nn.Cell):
def __init__(self):
super(NetSparseDenseMatmul, self).__init__()
self.matmul = nn.SparseTensorDenseMatmul()
def construct(self, indices, values, dens_shape, dt):
return self.matmul(indices, values, dens_shape, dt)
class SparseDenseMatmulNet(nn.Cell):
def __init__(self, adjoint_st=False, adjoint_dt=False):
super(SparseDenseMatmulNet, self).__init__()
self.matmul = nn.SparseTensorDenseMatmul(adjoint_st, adjoint_dt)
class NetSparseTensor(nn.Cell):
def __init__(self, dense_shape):
super(NetSparseTensor, self).__init__()
self.dense_shape = dense_shape
def construct(self, indices, values):
x = SparseTensor(indices, values, self.dense_shape)
return x.values, x.indices, x.dense_shape
def construct(self, indices, values, dens_shape, dense):
return self.matmul(indices, values, dens_shape, dense)
def test_sparse_tensor_dense_matmul():
indices = Tensor([[0, 1], [1, 1]])
values = Tensor([5, 5], dtype=ms.float32)
dens_shape = (3, 3)
spMatrix = np.array([[5, 0, 0], [0, 5, 0], [0, 0, 5]], dtype=np.float32)
dsMatrix = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]], dtype=np.float32)
test_SparseDenseMatmul = NetSparseDenseMatmul()
out_ms = test_SparseDenseMatmul(indices, values, dens_shape, Tensor(dsMatrix))
out_np = np.matmul(spMatrix, dsMatrix)
error = np.ones(shape=dsMatrix.shape) * 10e-6
diff = out_ms.asnumpy() - out_np
assert np.all(diff < error)
class GradNet(nn.Cell):
def __init__(self, network):
super(GradNet, self).__init__()
self.grad = C.GradOperation(get_all=True, sens_param=False)
self.network = network
def construct(self, indices, values, dens_shape, dense):
return self.grad(self.network)(indices, values, dens_shape, dense)
def judge_result_correct(result, expect):
assert result.dtype == expect.dtype
assert result.shape == expect.shape
assert np.allclose(result, expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sparse_tensor_dense_matmul_no_transpose():
indices_np = np.array([[0, 0], [1, 1], [2, 2], [2, 3]], np.int64)
values_np = np.array([2, 3, 4, 5], np.float32)
dense_shape = (3, 4)
sparse_np = np.array([[2, 0, 0, 0], [0, 3, 0, 0], [0, 0, 4, 5]], dtype=np.float32)
dense_np = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=np.float32)
sparse_dense_matmul_net = SparseDenseMatmulNet()
indices = Tensor(indices_np)
values = Tensor(values_np)
dense = Tensor(dense_np)
out_ms = sparse_dense_matmul_net(indices, values, dense_shape, dense)
out_np = np.matmul(sparse_np, dense_np)
judge_result_correct(out_ms.asnumpy(), out_np)
grad_net = GradNet(sparse_dense_matmul_net)
grad_ms = grad_net(indices, values, dense_shape, dense)
expect_values_grad = np.array([3., 12., 21., 30.], dtype=np.float32)
judge_result_correct(grad_ms[1].asnumpy(), expect_values_grad)
expect_dense_grad = np.array([[2., 2., 2.],
[3., 3., 3.],
[4., 4., 4.],
[5., 5., 5.]], dtype=np.float32)
judge_result_correct(grad_ms[2].asnumpy(), expect_dense_grad)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sparse_tensor_dense_matmul_transpose_a():
indices_np = np.array([[0, 0], [1, 1], [2, 0], [2, 2], [3, 1], [3, 2]], np.int32)
values_np = np.array([1, 2, 3, 4, 5, 6], np.float64)
dense_shape = (4, 3)
sparse_np = np.array([[1, 0, 0], [0, 2, 0], [3, 0, 4], [0, 5, 6]], dtype=np.float64)
dense_np = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=np.float64)
sparse_dense_matmul_net = SparseDenseMatmulNet(adjoint_st=True)
indices = Tensor(indices_np)
values = Tensor(values_np)
dense = Tensor(dense_np)
out_ms = sparse_dense_matmul_net(indices, values, dense_shape, dense)
perm = (1, 0)
out_np = np.matmul(np.transpose(sparse_np, perm), dense_np)
judge_result_correct(out_ms.asnumpy(), out_np)
grad_net = GradNet(sparse_dense_matmul_net)
grad_ms = grad_net(indices, values, dense_shape, dense)
expect_values_grad = np.array([3., 12., 21., 21., 30., 30.], dtype=np.float64)
judge_result_correct(grad_ms[1].asnumpy(), expect_values_grad)
expect_dense_grad = np.array([[1., 1., 1.],
[2., 2., 2.],
[7., 7., 7.],
[11., 11., 11.]], dtype=np.float64)
judge_result_correct(grad_ms[2].asnumpy(), expect_dense_grad)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sparse_tensor_dense_matmul_transpose_b():
indices_np = np.array([[0, 0], [1, 1], [2, 0], [2, 2], [3, 1], [3, 2]], np.int64)
values_np = np.array([1, 2, 3, 4, 5, 6], np.int32)
dense_shape = (4, 3)
sparse_np = np.array([[1, 0, 0], [0, 2, 0], [3, 0, 4], [0, 5, 6]], dtype=np.int32)
dense_np = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=np.int32)
sparse_dense_matmul_net = SparseDenseMatmulNet(adjoint_dt=True)
indices = Tensor(indices_np)
values = Tensor(values_np)
dense = Tensor(dense_np)
out_ms = sparse_dense_matmul_net(indices, values, dense_shape, dense)
perm = (1, 0)
out_np = np.matmul(sparse_np, np.transpose(dense_np, perm))
judge_result_correct(out_ms.asnumpy(), out_np)
grad_net = GradNet(sparse_dense_matmul_net)
grad_ms = grad_net(indices, values, dense_shape, dense)
expect_values_grad = np.array([18., 22., 18., 26., 22., 26.], dtype=np.int32)
judge_result_correct(grad_ms[1].asnumpy(), expect_values_grad)
expect_dense_grad = np.array([[4, 7, 10],
[4, 7, 10],
[4, 7, 10],
[4, 7, 10]], dtype=np.int32)
judge_result_correct(grad_ms[2].asnumpy(), expect_dense_grad)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sparse_tensor_dense_matmul_transpose_all():
indices_np = np.array([[0, 0], [1, 1], [2, 2], [2, 3]], np.int64)
values_np = np.array([2, 3, 4, 5], np.int64)
dense_shape = (3, 4)
sparse_np = np.array([[2, 0, 0, 0], [0, 3, 0, 0], [0, 0, 4, 5]], dtype=np.int64)
dense_np = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=np.int64)
sparse_dense_matmul_net = SparseDenseMatmulNet(adjoint_st=True, adjoint_dt=True)
indices = Tensor(indices_np)
values = Tensor(values_np)
dense = Tensor(dense_np)
out_ms = sparse_dense_matmul_net(indices, values, dense_shape, dense)
perm = (1, 0)
out_np = np.matmul(np.transpose(sparse_np, perm), np.transpose(dense_np, perm))
judge_result_correct(out_ms.asnumpy(), out_np)
grad_net = GradNet(sparse_dense_matmul_net)
grad_ms = grad_net(indices, values, dense_shape, dense)
expect_values_grad = np.array([18, 22, 26, 26], dtype=np.int64)
judge_result_correct(grad_ms[1].asnumpy(), expect_values_grad)
expect_dense_grad = np.array([[2, 3, 9],
[2, 3, 9],
[2, 3, 9],
[2, 3, 9]], dtype=np.int64)
judge_result_correct(grad_ms[2].asnumpy(), expect_dense_grad)