[feat] [assistant] [I5EWKK] implement CholeskySolve operator in a new way

This commit is contained in:
linjie 2022-10-18 18:23:32 +08:00
parent b04d427eeb
commit e0d36aa741
6 changed files with 218 additions and 145 deletions

View File

@@ -15,16 +15,17 @@
*/
#include "plugin/device/gpu/kernel/math/cholesky_solve_gpu_kernel.h"
#include "mindspore/core/ops/cholesky_solve.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
CholeskySolve,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
CholeskySolveGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
CholeskySolve,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
CholeskySolveGpuKernelMod, double)
using CSGKM = CholeskySolveGpuKernelMod;
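// Each entry pairs a supported dtype signature (KernelAttr) with the matching typed LaunchKernel instantiation.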
std::vector<std::pair<KernelAttr, CSGKM::CholeskySolveFunc>> CSGKM::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&CholeskySolveGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&CholeskySolveGpuKernelMod::LaunchKernel<double>},
};
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, CholeskySolve, CholeskySolveGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -20,13 +20,15 @@
#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include <algorithm>
#include <map>
#include <utility>
#include <algorithm>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/triangle_matrix_copy_impl.cuh"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/kernel_constants.h"
#include "include/common/utils/convert_utils.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/matrix_transpose_impl.cuh"
namespace mindspore {
namespace kernel {
@@ -36,163 +38,174 @@ constexpr size_t kCholeskyOutputsNum = 1;
constexpr size_t kOutputIndex = 0;
constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;
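// Type-dispatch wrappers: forward to the float (S) or double (D) cuBLAS triangular-solve routine,
// so the templated kernel can call a single name for both dtypes.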
inline cublasStatus_t cublasXtrsm(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo,
cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const float *alpha,
const float *A, int lda, float *B, int ldb) {
return cublasStrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb);
}
inline cublasStatus_t cublasXtrsm(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo,
cublasOperation_t trans, cublasDiagType_t diag, int m, int n, const double *alpha,
const double *A, int lda, double *B, int ldb) {
return cublasDtrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb);
}
inline cublasStatus_t cublasXtrsmBatched(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo,
cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
const float *alpha, const float *const A[], int lda, float *const B[], int ldb,
int batchCount) {
return cublasStrsmBatched(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount);
}
inline cublasStatus_t cublasXtrsmBatched(cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo,
cublasOperation_t trans, cublasDiagType_t diag, int m, int n,
const double *alpha, const double *const A[], int lda, double *const B[],
int ldb, int batchCount) {
return cublasDtrsmBatched(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount);
}
template <typename T>
class CholeskySolveGpuKernelMod : public NativeGpuKernelMod {
public:
using pointer = T *;
CholeskySolveGpuKernelMod() = default;
~CholeskySolveGpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override {
constexpr size_t input_num = 2;
constexpr size_t output_num = 1;
CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_);
kernel_name_ = base_operator->GetPrim()->name();
if (base_operator->HasAttr("upper")) {
upper_ = GetValue<bool>(base_operator->GetAttr("upper"));
kernel_name_ = base_operator->name();
upper_ = GetValue<bool>(base_operator->GetAttr("upper"));
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For 'CholeskySolve', it does not support this kernel type: " << kernel_attr;
return false;
}
// GPU inputs are column-major by default, so they need to be converted to row-major.
// To speed this up, just flip lower to upper: the Cholesky input a is a triangular matrix,
// so reinterpreting its memory layout amounts to a transpose, which swaps lower and upper.
// When input b has more than one column, a normal in-place transpose op may still be needed.
kernel_func_ = func_list_[index].second;
return true;
}
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
cuda_stream_ = reinterpret_cast<cudaStream_t>(stream_ptr);
return kernel_func_(this, inputs, workspace, outputs);
}
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
using pointer = T *;
CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(handle_, reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cholesky solve cublasSetStream failed");
auto input_a_addr = GetDeviceAddress<T>(inputs, kDim0);
auto input_b_addr = GetDeviceAddress<T>(inputs, kDim1);
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
auto d_a_array_addr = GetDeviceAddress<pointer>(workspace, kDim0);
auto d_b_array_addr = GetDeviceAddress<pointer>(workspace, kDim1);
auto d_c_array_addr = GetDeviceAddress<pointer>(workspace, kDim2);
std::vector<pointer> h_a_array(batch_num_);
std::vector<pointer> h_b_array(batch_num_);
std::vector<pointer> h_c_array(batch_num_);
for (size_t i = 0; i < batch_num_; i++) {
h_a_array[i] = input_a_addr + i * lda_ * nrhs_;
h_b_array[i] = input_b_addr + i * ldb_ * m_;
h_c_array[i] = output_addr + i * lda_ * nrhs_;
}
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(d_a_array_addr, h_a_array.data(), sizeof(pointer) * batch_num_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cuda memcopy Fail");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(d_b_array_addr, h_b_array.data(), sizeof(pointer) * batch_num_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cuda memcopy Fail");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(d_c_array_addr, h_c_array.data(), sizeof(pointer) * batch_num_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cuda memcopy Fail");
MatrixTranspose(input_a_addr, SizeToInt(batch_num_ * lda_ * nrhs_), SizeToInt(lda_), SizeToInt(nrhs_), output_addr,
device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER;
if (upper_) {
uplo_ = CUBLAS_FILL_MODE_LOWER;
} else {
uplo_ = CUBLAS_FILL_MODE_UPPER;
transa_ = CUBLAS_OP_N;
transa_t_ = CUBLAS_OP_T;
}
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCusolverDnHandle();
T alpha = 1;
if (batch_num_ == 1) {
CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(cublasXtrsm(handle_, CUBLAS_SIDE_LEFT, uplo_, transa_, CUBLAS_DIAG_NON_UNIT,
lda_, nrhs_, &alpha, input_b_addr, ldb_, output_addr, lda_),
"cholesky solve cublasXtrsm failed!");
CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(
cublasXtrsm(handle_, CUBLAS_SIDE_LEFT, uplo_, transa_t_, CUBLAS_DIAG_NON_UNIT, lda_, nrhs_, &alpha,
input_b_addr, ldb_, output_addr, lda_),
"cholesky solve cublasXtrsm failed!");
} else {
CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(
cublasXtrsmBatched(handle_, CUBLAS_SIDE_LEFT, uplo_, transa_, CUBLAS_DIAG_NON_UNIT, lda_, nrhs_, &alpha,
d_b_array_addr, ldb_, d_c_array_addr, lda_, batch_num_),
"cholesky solve cublasXgetrsBatched failed!");
CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(
cublasXtrsmBatched(handle_, CUBLAS_SIDE_LEFT, uplo_, transa_t_, CUBLAS_DIAG_NON_UNIT, lda_, nrhs_, &alpha,
d_b_array_addr, ldb_, d_c_array_addr, lda_, batch_num_),
"cholesky solve cublasXgetrsBatched failed!");
}
MatrixTranspose(output_addr, SizeToInt(batch_num_ * lda_ * nrhs_), SizeToInt(nrhs_), SizeToInt(lda_), input_a_addr,
device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
auto output_elements = batch_num_ * lda_ * nrhs_;
MatrixCopy(input_a_addr, output_addr, output_elements, reinterpret_cast<cudaStream_t>(cuda_stream_));
return true;
}
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
ResetResource();
auto in_a_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector());
auto in_b_shape = LongVecToSizeVec(inputs[kIndex1]->GetShapeVector());
(void)InitDim(in_a_shape, in_b_shape);
const auto b_shape = inputs.at(kIndex0)->GetShapeVector();
const auto cho_shape = inputs.at(kIndex1)->GetShapeVector();
is_null_input_ = CHECK_SHAPE_NULL(LongVecToSizeVec(b_shape), kernel_name_, "input_a") ||
CHECK_SHAPE_NULL(LongVecToSizeVec(cho_shape), kernel_name_, "input_b");
batch_num_ = std::accumulate(b_shape.begin(), b_shape.end() - kIndex2, int64_t(1), std::multiplies{});
m_ = cho_shape.back();
ldb_ = m_;
lda_ = m_;
nrhs_ = b_shape.back();
workspace_size_list_.clear();
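// Three per-batch device pointer arrays (a, b, c) plus a per-batch int info buffer.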
workspace_size_list_ = {batch_num_ * sizeof(float *), batch_num_ * sizeof(float *), batch_num_ * sizeof(float *),
batch_num_ * sizeof(int)};
return KRET_OK;
}
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cholesky solve cusolverDnSetStream failed");
auto input_a_addr = GetDeviceAddress<T>(inputs, kDim0);
auto input_b_addr = GetDeviceAddress<T>(inputs, kDim1);
auto output_addr = GetDeviceAddress<T>(outputs, kDim0);
auto d_a_array_addr = GetDeviceAddress<pointer>(workspace, kDim0);
auto d_b_array_addr = GetDeviceAddress<pointer>(workspace, kDim1);
auto d_info_array_addr = GetDeviceAddress<int>(workspace, kDim2);
for (size_t i = 0; i < outer_batch_; i++) {
h_a_array_[i] = input_a_addr + i * lda_ * m_;
h_b_array_[i] = input_b_addr + i * ldb_ * nrhs_;
}
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
cudaMemcpyAsync(d_a_array_addr, h_a_array_.data(), sizeof(pointer) * outer_batch_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
CHECK_CUDA_RET_WITH_ERROR_NOTRACE(
cudaMemcpyAsync(d_b_array_addr, h_b_array_.data(), sizeof(pointer) * outer_batch_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cuda memcopy Fail");
// cusolverDn<t>potrsBatched only supports nrhs == 1, so this path handles a single right-hand side.
if constexpr (std::is_same_v<T, float>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnSpotrsBatched(handle_, uplo_, m_, nrhs_, d_a_array_addr, lda_, d_b_array_addr, ldb_,
d_info_array_addr, outer_batch_),
"cusolver cholesky solve batched Fail");
} else if constexpr (std::is_same_v<T, double>) {
CHECK_CUSOLVER_RET_WITH_EXCEPT_NOTRACE(
cusolverDnDpotrsBatched(handle_, uplo_, m_, nrhs_, d_a_array_addr, lda_, d_b_array_addr, ldb_,
d_info_array_addr, outer_batch_),
"cusolver cholesky solve batched Fail");
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the data type only should be float or double, right now.";
}
size_t output_elements = outputs.at(kDim0)->size / unit_size_;
// Copy results from written input's matrix to output's matrix.
MatrixCopy(input_b_addr, output_addr, output_elements, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
void ResetResource() {
input_size_list_.clear();
workspace_size_list_.clear();
output_size_list_.clear();
h_b_array_.clear();
h_a_array_.clear();
}
protected:
void InitSizeLists() {
size_t input_size = outer_batch_ * m_ * lda_ * unit_size_;
input_size_list_.emplace_back(input_size);
input_size = outer_batch_ * nrhs_ * ldb_ * unit_size_;
input_size_list_.emplace_back(input_size);
size_t workspace_size = outer_batch_ * sizeof(pointer);
workspace_size_list_.emplace_back(workspace_size);
workspace_size_list_.emplace_back(workspace_size);
workspace_size = outer_batch_ * sizeof(int);
workspace_size_list_.emplace_back(workspace_size);
size_t output_size = outer_batch_ * m_ * unit_size_;
output_size_list_.push_back(output_size);
std::vector<KernelAttr> GetOpSupport() override {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, CholeskySolveFunc> &pair) { return pair.first; });
return support_list;
}
private:
void InitDim(const std::vector<size_t> &in_a_shape, const std::vector<size_t> &in_b_shape) {
constexpr size_t min_dim = 1;
if (in_a_shape.size() <= min_dim) {
MS_LOG_EXCEPTION << kernel_name_ << " input a shape dim is " << in_a_shape.size() << " which is invalid.";
}
cho_row_ = in_a_shape.at(in_a_shape.size() - kRowIndex);
cho_col_ = in_a_shape.at(in_a_shape.size() - kColIndex);
outer_batch_ = min_dim;
for (int batch = 0; batch < static_cast<int>(in_a_shape.size() - kRowIndex); ++batch) {
outer_batch_ *= in_a_shape.at(batch);
}
if (cho_row_ != cho_col_) {
MS_LOG_EXCEPTION << kernel_name_ << " input shape is invalid. "
<< "Cholesky expects a square matrix. but input a shape is: " << cho_row_ << ", " << cho_col_;
}
const bool is_right_equal_left = in_a_shape.size() == in_b_shape.size();
size_t b_row;
if (is_right_equal_left) {
b_row = in_b_shape.at(in_b_shape.size() - kRowIndex);
} else {
b_row = in_b_shape.back();
}
if (cho_row_ != b_row) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', right hand matrix should be equal to left matrix";
}
m_ = SizeToInt(cho_row_);
lda_ = m_;
ldb_ = m_;
h_a_array_.resize(outer_batch_);
h_b_array_.resize(outer_batch_);
InitSizeLists();
}
size_t cho_row_{0};
size_t cho_col_{0};
size_t unit_size_{sizeof(T)};
size_t nrhs_{1};
size_t outer_batch_{0};
using CholeskySolveFunc =
std::function<bool(CholeskySolveGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
size_t nrhs_{0};
size_t batch_num_{0};
size_t m_{0};
size_t lda_{0};
size_t ldb_{0};
cusolverDnHandle_t handle_{nullptr};
cublasFillMode_t uplo_ = CUBLAS_FILL_MODE_UPPER;
std::vector<pointer> h_a_array_;
std::vector<pointer> h_b_array_;
cublasHandle_t handle_{nullptr};
cublasOperation_t transa_{CUBLAS_OP_T};
cublasOperation_t transa_t_{CUBLAS_OP_N};
bool upper_{false};
bool is_null_input_{false};
CholeskySolveFunc kernel_func_;
static std::vector<std::pair<KernelAttr, CholeskySolveFunc>> func_list_;
void *cuda_stream_{nullptr};
};
} // namespace kernel
} // namespace mindspore
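For reference, the single-batch path in LaunchKernel above boils down to two triangular solves against the Cholesky factor, first with transa_ and then with transa_t_. A minimal NumPy/SciPy sketch of that math, assuming a lower factor L with A = L @ L.T (the function name is illustrative, not part of the kernel):

import numpy as np
from scipy.linalg import solve_triangular

def cholesky_solve_ref(L, b):
    """Solve (L @ L.T) x = b given the lower Cholesky factor L."""
    y = solve_triangular(L, b, lower=True)                 # forward substitution: L y = b
    return solve_triangular(L, y, lower=True, trans='T')   # back substitution: L.T x = y

L = np.array([[2., 0., 0.], [4., 1., 0.], [-1., 1., 2.]])
b = np.eye(3)
x = cholesky_solve_ref(L, b)
assert np.allclose((L @ L.T) @ x, b)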

View File

@@ -30,7 +30,7 @@ abstract::ShapePtr CholeskySolveInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const size_t kDefalutRank = 2;
const size_t kBatchRank = 3;
const size_t kBatchRank = 1;
const size_t kBatchIndex = 3;
const size_t kRowIndex = 2;
const size_t kColIndex = 1;
@@ -44,12 +44,12 @@ abstract::ShapePtr CholeskySolveInferShape(const PrimitivePtr &primitive,
out_shape.push_back(abstract::Shape::kShapeRankAny);
return std::make_shared<abstract::Shape>(out_shape);
}
if (x1_shape.size() != kDefalutRank && x1_shape.size() != kBatchRank) {
MS_EXCEPTION(ValueError) << "For CholeskySolve, the rank of x1 must be equal to 2 or 3"
if (x1_shape.size() <= kBatchRank) {
MS_EXCEPTION(ValueError) << "For CholeskySolve, the rank of x1 have at least 2 dimensions"
<< ", while got x1 rank " << x1_shape.size() << ".";
}
if (x2_shape.size() != kDefalutRank && x2_shape.size() != kBatchRank) {
MS_EXCEPTION(ValueError) << "For CholeskySolve, the rank of x2 must be equal to 2 or 3"
if (x2_shape.size() <= kBatchRank) {
MS_EXCEPTION(ValueError) << "For CholeskySolve, the rank of x2 have at least 2 dimensions"
<< ", while got x2 rank " << x2_shape.size() << ".";
}
if (x1_shape.size() != x2_shape.size()) {
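The relaxed checks above accept any rank of at least 2, so the inputs may carry arbitrary leading batch dimensions instead of being limited to 2D or 3D. A shape-only NumPy sketch of what now passes validation (values are illustrative):

import numpy as np

x1 = np.zeros((5, 4, 3, 2))  # (*, N, NRHS) with two leading batch dims
x2 = np.zeros((5, 4, 3, 3))  # (*, N, N) batch of Cholesky factors
assert x1.ndim >= 2 and x2.ndim >= 2  # the new rank check
assert x1.ndim == x2.ndim             # ranks must still match
assert x2.shape[-1] == x2.shape[-2]   # factors must be square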

View File

@@ -1256,8 +1256,11 @@ def get_bprop_cholesky_solve(self):
else:
dx2 = neg_op(matmul_op(common_term, x2))
else:
common_term = batchmatmul_op(dx1, transpose(out, (0, 2, 1)))
common_term = common_term + transpose(common_term, (0, 2, 1))
x2_dim_size = len(shape_op(x2))
x2_dim_order = list(range(x2_dim_size))
target_order = x2_dim_order[:-2] + x2_dim_order[-2:][::-1]
common_term = batchmatmul_op(dx1, transpose(out, tuple(target_order)))
common_term = common_term + transpose(common_term, tuple(target_order))
if upper is True:
dx2 = neg_op(batchmatmul_op(x2, common_term))
else:
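The target_order computed above simply swaps the last two axes for a tensor of any batch rank, generalizing the previously hard-coded (0, 2, 1) permutation for the 3D case. A standalone check of that permutation (the helper name is made up for illustration):

def last_two_swapped(ndim):
    order = list(range(ndim))
    order[-2], order[-1] = order[-1], order[-2]
    return tuple(order)

assert last_two_swapped(3) == (0, 2, 1)     # matches the old fixed order
assert last_two_swapped(4) == (0, 1, 3, 2)  # now also covers extra batch dims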

View File

@@ -6655,6 +6655,7 @@ class CholeskySolve(Primitive):
with float32 or float64 data type.
- **x2** (Tensor) - Tensor of shape :math:`(*, N, N)`, indicating 2D or 3D square matrices composed of
upper or lower triangular Cholesky factor, with float32 or float64 data type.
`x1` and `x2` must have the same data type.
Outputs:
Tensor, has the same shape and data type as `x1`.
@@ -6670,7 +6671,7 @@
ValueError: If `x2` is not a 2D or 3D square matrix.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x1 = Tensor(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), mindspore.float32)
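As context for the docstring: the output y satisfies A @ y = x1, where A = x2 @ x2.T when upper is False (and x2.T @ x2 when True). A SciPy cross-check under that reading, with arbitrarily chosen values:

import numpy as np
from scipy.linalg import cho_solve

A = np.array([[4., 8., -2.], [8., 17., -3.], [-2., -3., 6.]])
L = np.linalg.cholesky(A)      # lower triangular factor, A = L @ L.T
x1 = np.eye(3)
y = cho_solve((L, True), x1)   # True marks the factor as lower triangular
assert np.allclose(A @ y, x1)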

View File

@@ -0,0 +1,55 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
from mindspore import nn
from mindspore import context
from mindspore import Tensor
from mindspore.ops.operations.math_ops import CholeskySolve
class Net(nn.Cell):
"""a class used to test CholeskySolve gpu operator."""
def __init__(self, upper=False):
super(Net, self).__init__()
self.cholesky_solve = CholeskySolve(upper=upper)
def construct(self, x1, x2):
"""construct."""
return self.cholesky_solve(x1, x2)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cholesky_solve():
"""
Feature: CholeskySolve gpu TEST.
Description: test CholeskySolve operator
Expectation: the result match to numpy
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
x1 = Tensor(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), mindspore.float32)
x2 = Tensor(np.array([[2, 0, 0], [4, 1, 0], [-1, 1, 2]]), mindspore.float32)
expect = np.array([[5.8125, -2.625, 0.625], [-2.625, 1.25, -0.25], [0.625, -0.25, 0.25]])
net = Net()
mindspore_output = net(x1, x2)
diff = mindspore_output.asnumpy() - expect
error = np.ones(shape=expect.shape) * 1.0e-4
assert np.all(np.abs(diff) < error)
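The expect constant can be reproduced directly with NumPy: since x1 is the identity and x2 is a lower factor L, the solve returns the inverse of L @ L.T.

import numpy as np

L = np.array([[2., 0., 0.], [4., 1., 0.], [-1., 1., 2.]])
print(np.linalg.inv(L @ L.T))
# [[ 5.8125 -2.625   0.625 ]
#  [-2.625   1.25   -0.25  ]
#  [ 0.625  -0.25    0.25  ]]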