[assistant][ops] Add sparse operator SparseAddmm

This commit is contained in:
luobingchun 2022-08-03 18:34:58 +08:00
parent f35924192e
commit 7e5b4d69dd
10 changed files with 750 additions and 1 deletions
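
For reference, SparseAddmm computes output = alpha * x1 @ x2_dense + beta * x3_dense, where x1 is a rank-2 COO sparse matrix described by (indices, values, sparse_shape) and x2_dense, x3_dense are dense matrices. A minimal NumPy sketch of these semantics (an illustration only, not the shipped kernel; sparse_addmm_ref is a hypothetical helper name):

import numpy as np

def sparse_addmm_ref(indices, values, sparse_shape, x2, x3, alpha, beta):
    """NumPy reference: out = alpha * (COO sparse x1) @ x2 + beta * x3."""
    rows, cols = sparse_shape
    assert x2.shape[0] == cols and x3.shape == (rows, x2.shape[1])
    out = beta * x3.astype(np.float64)
    for (row, col), v in zip(indices, values):
        # Scatter one stored nonzero of x1: out[row, :] += alpha * v * x2[col, :]
        out[row, :] += alpha * v * x2[col, :]
    return out.astype(x3.dtype)

indices = np.array([[0, 1], [1, 2]])
values = np.array([1.0, 2.0], dtype=np.float32)
x2 = np.array([[1, 1], [2, 2], [3, 3], [4, 4]], dtype=np.float32)
x3 = np.array([[2, 2], [6, 6], [0, 0]], dtype=np.float32)
print(sparse_addmm_ref(indices, values, (3, 4), x2, x3, alpha=1.0, beta=1.0))
# [[ 4.  4.]
#  [12. 12.]
#  [ 0.  0.]]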


@@ -0,0 +1,370 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/sparse_addmm_cpu_kernel.h"
#include <algorithm>
#include <functional>
#include <utility>
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kSparseAddmmInputsNum = 7;
constexpr size_t kSparseAddmmOutputsNum = 1;
constexpr size_t kSparseAddmmOutputShapeSize = 2;
constexpr size_t kSparseAddmmDenseShapeSize = 2;
constexpr size_t kIndicesSizeNum = 2;
constexpr size_t kIndices2ndDimNum = 2;
constexpr size_t kShapeValue = 0;
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
constexpr size_t kIndex3 = 3;
constexpr size_t kIndex4 = 4;
constexpr size_t kIndex5 = 5;
constexpr size_t kIndex6 = 6;
} // namespace
void SparseAddmmCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
auto indices_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, INDICES);
if (indices_shape.size() != kIndicesSizeNum || indices_shape[1] != kIndices2ndDimNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', 'indices' should be a 2-D Tensor whose second dimension length is 2, "
"but got 'indices' shape: "
<< Vector2Str(indices_shape);
}
auto values_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, VALUES);
if (values_shape.size() != 1 || values_shape[0] != indices_shape[0]) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', it requires 'values' should be a 1-D Tensor and the first dimension length "
" should be equal to the first dimension length of 'indices', but got 'values' shape: "
<< Vector2Str(values_shape) << " and 'indices' shape: " << Vector2Str(indices_shape);
}
output_shape_ = Convert2SizeT(common::AnfAlgo::GetOutputInferShape(kernel_node, 0));
values_size_ = LongToSize(values_shape[0]);
b_shape_ = Convert2SizeT(common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, DENSE));
if (b_shape_.size() != kSparseAddmmDenseShapeSize) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'dense' should be "
<< kSparseAddmmDenseShapeSize << "-D, but got " << b_shape_.size() << "-D";
}
if (output_shape_.size() != kSparseAddmmOutputShapeSize) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output should be "
<< kSparseAddmmOutputShapeSize << "-D, but got " << output_shape_.size() << "-D";
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "SparseAddmm does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
template <typename I, typename T>
bool SparseAddmmCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseAddmmInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseAddmmOutputsNum, kernel_name_);
auto ret = memset_s(outputs[0]->addr, outputs[0]->size, 0, outputs[0]->size);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset output failed. Error no: " << ret;
}
auto *a_indices = reinterpret_cast<I *>(inputs[kIndex0]->addr);
auto *a_values = reinterpret_cast<T *>(inputs[kIndex1]->addr);
auto *x1_shape = reinterpret_cast<I *>(inputs[kIndex2]->addr);
auto *b = reinterpret_cast<T *>(inputs[kIndex3]->addr);
auto *c = reinterpret_cast<T *>(inputs[kIndex4]->addr);
auto *alpha = reinterpret_cast<T *>(inputs[kIndex5]->addr);
auto *beta = reinterpret_cast<T *>(inputs[kIndex6]->addr);
auto *out = reinterpret_cast<T *>(outputs[kIndex0]->addr);
const size_t indices_length = inputs[kIndex0]->size / sizeof(I);
const size_t values_length = inputs[kIndex1]->size / sizeof(T);
const size_t b_length = inputs[kIndex3]->size / sizeof(T);
const size_t dim_num = 2;
const size_t out_dim_0 = output_shape_[0];
const size_t out_dim_1 = output_shape_[1];
const size_t b_dim_0 = b_shape_[0];
const size_t b_dim_1 = b_shape_[1];
const size_t same_dim = b_dim_0;
const I x1_shape_0 = x1_shape[0];
const I x1_shape_1 = x1_shape[1];
const size_t x1_shape_0_s = IntToSize(x1_shape_0);
const size_t x1_shape_1_s = IntToSize(x1_shape_1);
if (x1_shape_0_s <= kShapeValue || x1_shape_1_s <= kShapeValue) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of 'x1_shape' should be greater than 0.";
}
if (x1_shape_1_s != b_dim_0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the col of 'x1_shape' should be equal to the row of 'x2_dense',"
" but got col: "
<< x1_shape_1_s << ", row: " << b_dim_0;
}
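// Scatter pass: for each stored nonzero (row, col) of the sparse x1, accumulate
// alpha * values[i] * x2[col, :] into out[row, :].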
for (size_t i = 0; i < values_size_; ++i) {
if (i * dim_num + 1 >= indices_length) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the index into 'indices' is out of bounds.";
}
if (i >= values_length) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the index into 'values' is out of bounds.";
}
const int row = a_indices[i * dim_num];
const int col = a_indices[i * dim_num + 1];
if (row >= SizeToInt(out_dim_0) || row < 0 || col >= SizeToInt(same_dim) || col < 0) {
MS_EXCEPTION(ValueError) << "The indices including out of bounds index, row range: [0, " << out_dim_0
<< "), col range: [0, " << same_dim << "), but got row: " << row << ", col: " << col;
}
const size_t row_s = IntToSize(row);
const size_t col_s = IntToSize(col);
for (size_t n = 0; n < out_dim_1; ++n) {
if (col_s * b_dim_1 + n >= b_length) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the index into 'x2_dense' is out of bounds.";
}
const T b_value = b[col_s * b_dim_1 + n];
out[row_s * out_dim_1 + n] += *alpha * a_values[i] * b_value;
}
}
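// Dense pass: add beta * c (the x3_dense input) elementwise into the accumulated product.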
for (size_t i = 0; i < out_dim_0; ++i) {
for (size_t j = 0; j < out_dim_1; ++j) {
const T c_value = c[i * out_dim_1 + j];
out[i * out_dim_1 + j] += *beta * c_value;
}
}
return true;
}
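// Supported dtype combinations: the index inputs (indices, x1_shape) take int32 or int64;
// values, x2, x3, alpha, beta and the output share one numeric dtype.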
std::vector<std::pair<KernelAttr, SparseAddmmCpuKernelMod::SparseAddmmFunc>> SparseAddmmCpuKernelMod::func_list_ = {
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, int8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, int8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, int16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, int16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, float>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, double>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, double>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, uint8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, uint8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddOutputAttr(kNumberTypeUInt16),
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, uint16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeUInt16)
.AddOutputAttr(kNumberTypeUInt16),
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, uint16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeUInt32),
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, uint32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeUInt32),
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, uint32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddOutputAttr(kNumberTypeUInt64),
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, uint64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddInputAttr(kNumberTypeUInt64)
.AddOutputAttr(kNumberTypeUInt64),
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, uint64_t>}};
std::vector<KernelAttr> SparseAddmmCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, SparseAddmmFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseAddmm, SparseAddmmCpuKernelMod);
} // namespace kernel
} // namespace mindspore


@@ -0,0 +1,58 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_ADDMM_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_ADDMM_CPU_KERNEL_H_
#include <vector>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class SparseAddmmCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
SparseAddmmCpuKernelMod() = default;
~SparseAddmmCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename I, typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
using SparseAddmmFunc = std::function<bool(SparseAddmmCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, SparseAddmmFunc>> func_list_;
SparseAddmmFunc kernel_func_;
std::vector<size_t> output_shape_;
std::vector<size_t> b_shape_;
size_t output_size_{0};
size_t values_size_{0};
enum input_list_ { INDICES, VALUES, SPARSE_SHAPE, DENSE, DENSE_X2, ALPHA, BETA };
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_ADDMM_CPU_KERNEL_H_


@@ -288,6 +288,7 @@ constexpr auto kCOOTensorDenseMatmul = "COOTensorDenseMatmul";
// Sparse ops
constexpr auto kSparseTensorDenseMatmul = "SparseTensorDenseMatmul";
constexpr auto kSparseAddmm = "SparseAddmm";
constexpr auto kCSRReduceSum = "CSRReduceSum";
constexpr auto kCSRMV = "CSRMV";
constexpr auto kCSRMM = "CSRMM";
@@ -923,6 +924,7 @@ GVAR_DEF(PrimitivePtr, kPrimCSRTensorGetDenseShape, std::make_shared<Primitive>(
// Sparse ops
GVAR_DEF(PrimitivePtr, kPrimSparseTensorDenseMatmul, std::make_shared<Primitive>(kSparseTensorDenseMatmul));
GVAR_DEF(PrimitivePtr, kPrimSparseAddmm, std::make_shared<Primitive>(kSparseAddmm));
GVAR_DEF(PrimitivePtr, kPrimCOOTensorDenseMatmul, std::make_shared<Primitive>(kCOOTensorDenseMatmul));
GVAR_DEF(PrimitivePtr, kPrimCSRReduceSum, std::make_shared<Primitive>(kCSRReduceSum));
GVAR_DEF(PrimitivePtr, kPrimCSRMV, std::make_shared<Primitive>(kCSRMV));


@@ -0,0 +1,124 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <set>
#include <vector>
#include <memory>
#include <map>
#include "ops/sparse_addmm.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr SparseAddmmInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto values_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto x2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
auto x3_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
auto alpha_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape];
auto beta_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->BuildShape())[kShape];
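// Expected ranks: indices (n, 2), values (n,), sparse_shape (2,), x2 and x3 rank 2, alpha and beta rank 1.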
const int kDimensionTwo = 2;
const int kDimensionOne = 1;
if (indices_shape.size() != kDimensionTwo) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the input indices should "
<< "have rank 2, but got " << indices_shape.size() << ".";
}
if (indices_shape[1] != kDimensionTwo) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the 2nd dimension of indices "
<< "should be 2, but got " << indices_shape[1] << ".";
}
if (values_shape.size() != kDimensionOne) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the input values should "
<< "have rank 1, but got " << values_shape.size() << ".";
}
if (shape_shape.size() != kDimensionOne) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the input sparse_shape should "
<< "have rank 1, but got " << shape_shape.size() << ".";
}
if (shape_shape[0] != kDimensionTwo) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the 1st dimension of sparse_shape "
<< "should be 2, but got " << shape_shape[0] << ".";
}
if (x2_shape.size() != kDimensionTwo) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the input x2 (dense matrix) should "
<< "have rank 2, but got " << x2_shape.size() << ".";
}
if (x3_shape.size() != kDimensionTwo) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the input x3 (dense matrix) should "
<< "have rank 2, but got " << x3_shape.size() << ".";
}
if (alpha_shape.size() != kDimensionOne) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the input alpha should "
<< "have rank 1, but got " << alpha_shape.size() << ".";
}
if (beta_shape.size() != kDimensionOne) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the input beta should "
<< "have rank 1, but got " << beta_shape.size() << ".";
}
return std::make_shared<abstract::Shape>(x3_shape);
}
TypePtr SparseAddmmInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
std::map<std::string, TypePtr> types;
std::set<TypePtr> valid_types = {kFloat32, kFloat64, kInt32, kInt64, kInt16,
kInt8, kUInt32, kUInt64, kUInt16, kUInt8};
TypePtr indices_type = input_args[kInputIndex0]->BuildType();
TypePtr values_type = input_args[kInputIndex1]->BuildType();
TypePtr shape_type = input_args[kInputIndex2]->BuildType();
TypePtr x2_type = input_args[kInputIndex3]->BuildType();
TypePtr x3_type = input_args[kInputIndex4]->BuildType();
TypePtr alpha_type = input_args[kInputIndex5]->BuildType();
TypePtr beta_type = input_args[kInputIndex6]->BuildType();
auto prim_name = primitive->name();
(void)types.emplace("x1_values", values_type);
(void)types.emplace("x2", x2_type);
(void)types.emplace("x3", x3_type);
(void)types.emplace("alpha", alpha_type);
(void)types.emplace("beta", beta_type);
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, primitive->name());
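// indices and sparse_shape accept only int32/int64, independently of the value dtype checked above.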
const std::set<TypePtr> valid_type = {kInt64, kInt32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices", indices_type, valid_type, prim_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("sparse_shape", shape_type, valid_type, prim_name);
auto tensor_type = x3_type->cast<TensorTypePtr>();
auto tensor_element = tensor_type->element();
return tensor_element;
}
} // namespace
AbstractBasePtr SparseAddmmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 7;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
// infer type
auto infer_type = SparseAddmmInferType(primitive, input_args);
// infer shape
auto infer_shape = SparseAddmmInferShape(primitive, input_args);
return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape);
}
MIND_API_OPERATOR_IMPL(SparseAddmm, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(SparseAddmm, prim::kPrimSparseAddmm, SparseAddmmInfer, nullptr, true);
} // namespace ops
} // namespace mindspore


@@ -0,0 +1,41 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SPARSE_ADDMM_H_
#define MINDSPORE_CORE_OPS_SPARSE_ADDMM_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSparseAddmm = "SparseAddmm";
class MIND_API SparseAddmm : public BaseOperator {
public:
MIND_API_BASE_MEMBER(SparseAddmm);
SparseAddmm() : BaseOperator(kNameSparseAddmm) {
InitIOName({"indices", "values", "sparse_shape", "x2_dense", "x3_dense", "alpha", "beta"}, {"output"});
}
void Init() const {}
};
abstract::AbstractBasePtr SparseAddmmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using PrimSparseAddmmPtr = std::shared_ptr<SparseAddmm>;
} // namespace ops
} // namespace mindspore
#endif  // MINDSPORE_CORE_OPS_SPARSE_ADDMM_H_


@@ -98,6 +98,7 @@ from .random_poisson import _random_poisson_aicpu
from .random_choice_with_mask import _random_choice_with_mask_aicpu
from .rsqrt import _rsqrt_aicpu
from .rsqrt_grad import _rsqrt_grad_aicpu
from .sparseaddmm import _sparse_addmm_aicpu
from .search_sorted import _search_sorted_aicpu
from .stack import _stack_aicpu
from .unstack import _unstack_aicpu


@@ -0,0 +1,87 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SparseAddmm op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
sparseaddmm_op_info = AiCPURegOp("SparseAddmm") \
.fusion_type("OPAQUE") \
.input(0, "x1_indices", "required") \
.input(1, "x1_values", "required") \
.input(2, "x1_shape", "required") \
.input(3, "x2", "required") \
.input(4, "x3", "required") \
.input(5, "alpha", "required") \
.input(6, "beta", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I32_Default, DataType.I8_Default,
DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I64_Default, DataType.I8_Default, DataType.I64_Default, DataType.I8_Default,
DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.I32_Default, DataType.I16_Default, DataType.I32_Default, DataType.I16_Default,
DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I64_Default, DataType.I16_Default, DataType.I64_Default, DataType.I16_Default,
DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default, DataType.I32_Default,
DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I64_Default,
DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.I32_Default, DataType.U8_Default,
DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I64_Default, DataType.U8_Default, DataType.I64_Default, DataType.U8_Default,
DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.I32_Default, DataType.U16_Default, DataType.I32_Default, DataType.U16_Default,
DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.I64_Default, DataType.U16_Default, DataType.I64_Default, DataType.U16_Default,
DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
.dtype_format(DataType.I32_Default, DataType.U32_Default, DataType.I32_Default, DataType.U32_Default,
DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.I64_Default, DataType.U32_Default, DataType.I64_Default, DataType.U32_Default,
DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.I32_Default, DataType.U64_Default, DataType.I32_Default, DataType.U64_Default,
DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.I64_Default, DataType.U64_Default,
DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.I32_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.I64_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.I32_Default, DataType.F64_Default,
DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.I64_Default, DataType.F64_Default,
DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.I32_Default, DataType.C64_Default, DataType.I32_Default, DataType.C64_Default,
DataType.C64_Default, DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
.dtype_format(DataType.I64_Default, DataType.C64_Default, DataType.I64_Default, DataType.C64_Default,
DataType.C64_Default, DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
.dtype_format(DataType.I32_Default, DataType.C128_Default, DataType.I32_Default, DataType.C128_Default,
DataType.C128_Default, DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
.dtype_format(DataType.I64_Default, DataType.C128_Default, DataType.I64_Default, DataType.C128_Default,
DataType.C128_Default, DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
.get_op_info()
@op_info_register(sparseaddmm_op_info)
def _sparse_addmm_aicpu():
"""SparseAddmm AiCPU register"""
return


@@ -1,6 +1,6 @@
# coding: utf-8
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -531,6 +531,58 @@ class Sspaddmm(Primitive):
'x3_dense', 'alpha', 'beta'], outputs=['y_indices', 'y_values', 'y_shape'])
class SparseAddmm(Primitive):
"""
Computes ``alpha * A @ B + beta * C``, where `A` is a sparse matrix described by `indices`, `values` and
`sparse_shape`, and `B`, `C` are dense matrices.
The sparse matrix and both dense matrices must have rank 2.
Inputs:
- **indices** (Tensor) - A 2-D Tensor, represents the position of the element in the sparse tensor.
Support int32, int64, each element value should be a non-negative int number. The shape is :math:`(n, 2)`.
- **values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in the `indices`.
Support float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64.
The shape should be :math:`(n,)`.
- **sparse_shape** (Tensor) - A 1-D Tensor of positive integers which specifies the shape of the sparse tensor.
Support int32, int64, should have 2 elements, representing the sparse tensor shape :math:`(N, C)`.
- **x2_dense** (Tensor) - A 2-D Tensor, the dtype is same as `values`.
- **x3_dense** (Tensor) - A 2-D Tensor, the dtype is same as `values`.
- **alpha** (Tensor) - A 1-D Tensor, the dtype is same as `values`.
- **beta** (Tensor) - A 1-D Tensor, the dtype is same as `values`.
Outputs:
Tensor, with the same dtype as `values` and the same shape as `x3_dense`.
Raises:
TypeError: If the dtype of `indices`, `values`, `x2_dense`, `x3_dense`, `alpha` or `beta` doesn't meet the
parameter description.
ValueError: If `sparse_shape`, or the shape of `indices`, `values`, `x2_dense` or `x3_dense`, doesn't meet the
parameter description.
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> indices = Tensor([[0, 1], [1, 2]], dtype=ms.int32)
>>> values = Tensor([1, 2], dtype=ms.float32)
>>> sparse_shape = Tensor([3, 4], dtype=ms.int32)
>>> x2_dense = Tensor([[1,1], [2,2], [3,3], [4,4]], dtype=ms.float32)
>>> x3_dense = Tensor([[2,2], [6,6], [0,0]], dtype=ms.float32)
>>> alpha = Tensor([1], dtype=ms.float32)
>>> beta = Tensor([1], dtype=ms.float32)
>>> sparse_addmm = ops.SparseAddmm()
>>> out = sparse_addmm(indices, values, sparse_shape, x2_dense, x3_dense, alpha, beta)
>>> print(out)
[[ 4.  4.]
[12. 12.]
[ 0.  0.]]
"""
@prim_attr_register
def __init__(self):
"""Initialize SparseAddmm"""
self.init_prim_io_names(inputs=['indices', 'values', 'sparse_shape', 'x2_dense', 'x3_dense', 'alpha', 'beta'],
outputs=['output'])
class SparseConcat(Primitive):
"""
Concatenates the input SparseTensor (COO format) along the specified dimension. This is a demo API for now.


@@ -224,6 +224,9 @@ class InputOpNet(nn.Cell):
x = self.op(x1, x2, x3, x4, x5, self.c1, self.c2, self.c3, self.c4)
return x
def construct7_c0(self, x1, x2, x3, x4, x5, x6, x7):
x = self.op(x1, x2, x3, x4, x5, x6, x7)
return x
def construct9_c0(self, x1, x2, x3, x4, x5, x6, x7, x8, x9):
x = self.op(x1, x2, x3, x4, x5, x6, x7, x8, x9)


@@ -125,6 +125,7 @@ from mindspore.ops.operations.sparse_ops import SparseMatrixNNZ
from mindspore.ops.operations.sparse_ops import SparseTensorDenseAdd
from mindspore.ops.operations.sparse_ops import SparseMatrixTranspose
from mindspore.ops.operations.sparse_ops import CSRSparseMatrixToSparseTensor
from mindspore.ops.operations.sparse_ops import SparseAddmm
from mindspore.ops.operations.sparse_ops import SparseTensorToCSRSparseMatrix
from mindspore.ops.operations.sparse_ops import SparseSparseMinimum
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtN
@@ -3227,6 +3228,16 @@ test_case_array_ops = [
'block': UnravelIndex(),
'desc_inputs': [Tensor(np.array([5, 5]).astype(np.int64)), Tensor(np.array([3, 3]).astype(np.int64))],
'skip': ['backward']}),
('SparseAddmm', {
'block': SparseAddmm(),
'desc_inputs': [Tensor(np.array([[0, 1], [1, 2]]).astype(np.int32)),
Tensor(np.array([1, 2]).astype(np.int32)),
Tensor(np.array([3, 4]).astype(np.int32)),
Tensor(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype(np.int32)),
Tensor(np.array([[2, 2], [6, 6], [0, 0]]).astype(np.int32)),
Tensor(np.array([1]).astype(np.int32)),
Tensor(np.array([1]).astype(np.int32))],
'skip': ['backward']}),
('SpaceToDepth', {
'block': P.SpaceToDepth(2),
'desc_inputs': [[1, 3, 2, 2]],