[assistant][ops] Add sparse operator SparseAddmm
This commit is contained in:
parent
f35924192e
commit
7e5b4d69dd
|
@ -0,0 +1,370 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/sparse_addmm_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kSparseAddmmInputsNum = 7;
|
||||
constexpr size_t kSparseAddmmOutputsNum = 1;
|
||||
constexpr size_t kSparseAddmmOutputShapeSize = 2;
|
||||
constexpr size_t kSparseAddmmDenseShapeSize = 2;
|
||||
constexpr size_t kIndicesSizeNum = 2;
|
||||
constexpr size_t kIndices2rdDimNum = 2;
|
||||
constexpr size_t kShapeValue = 0;
|
||||
constexpr size_t kIndex0 = 0;
|
||||
constexpr size_t kIndex1 = 1;
|
||||
constexpr size_t kIndex2 = 2;
|
||||
constexpr size_t kIndex3 = 3;
|
||||
constexpr size_t kIndex4 = 4;
|
||||
constexpr size_t kIndex5 = 5;
|
||||
constexpr size_t kIndex6 = 6;
|
||||
} // namespace
|
||||
|
||||
void SparseAddmmCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
auto indices_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, INDICES);
|
||||
if (indices_shape.size() != kIndicesSizeNum && indices_shape[1] != kIndices2rdDimNum) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', it requires 'indices' should be a 2-D Tensor and the second dimension length "
|
||||
"should be 2, but got 'indices' shape: "
|
||||
<< Vector2Str(indices_shape);
|
||||
}
|
||||
auto values_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, VALUES);
|
||||
if (values_shape.size() != 1 || values_shape[0] != indices_shape[0]) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', it requires 'values' should be a 1-D Tensor and the first dimension length "
|
||||
" should be equal to the first dimension length of 'indices', but got 'values' shape: "
|
||||
<< Vector2Str(values_shape) << " and 'indices' shape: " << Vector2Str(indices_shape);
|
||||
}
|
||||
output_shape_ = Convert2SizeT(common::AnfAlgo::GetOutputInferShape(kernel_node, 0));
|
||||
values_size_ = LongToSize(values_shape[0]);
|
||||
b_shape_ = Convert2SizeT(common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, DENSE));
|
||||
if (b_shape_.size() != kSparseAddmmDenseShapeSize) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'dense' should be "
|
||||
<< kSparseAddmmDenseShapeSize << "-D, but got " << b_shape_.size() << "-D";
|
||||
}
|
||||
if (output_shape_.size() != kSparseAddmmOutputShapeSize) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output should be "
|
||||
<< kSparseAddmmOutputShapeSize << "-D, but got " << output_shape_.size() << "-D";
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "SparseAddmm does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
|
||||
template <typename I, typename T>
|
||||
bool SparseAddmmCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSparseAddmmInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSparseAddmmOutputsNum, kernel_name_);
|
||||
auto ret = memset_s(outputs[0]->addr, outputs[0]->size, 0, outputs[0]->size);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset output failed. Error no: " << ret;
|
||||
}
|
||||
|
||||
auto *a_indices = reinterpret_cast<I *>(inputs[kIndex0]->addr);
|
||||
auto *a_values = reinterpret_cast<T *>(inputs[kIndex1]->addr);
|
||||
auto *x1_shape = reinterpret_cast<I *>(inputs[kIndex2]->addr);
|
||||
auto *b = reinterpret_cast<T *>(inputs[kIndex3]->addr);
|
||||
auto *c = reinterpret_cast<T *>(inputs[kIndex4]->addr);
|
||||
auto *alpha = reinterpret_cast<T *>(inputs[kIndex5]->addr);
|
||||
auto *beta = reinterpret_cast<T *>(inputs[kIndex6]->addr);
|
||||
auto *out = reinterpret_cast<T *>(outputs[kIndex0]->addr);
|
||||
|
||||
const size_t indices_length = inputs[kIndex0]->size / sizeof(I);
|
||||
const size_t values_length = inputs[kIndex1]->size / sizeof(T);
|
||||
const size_t b_length = inputs[kIndex3]->size / sizeof(T);
|
||||
|
||||
const size_t dim_num = 2;
|
||||
const size_t out_dim_0 = output_shape_[0];
|
||||
const size_t out_dim_1 = output_shape_[1];
|
||||
const size_t b_dim_0 = b_shape_[0];
|
||||
const size_t b_dim_1 = b_shape_[1];
|
||||
const size_t same_dim = b_dim_0;
|
||||
|
||||
const I x1_shape_0 = x1_shape[0];
|
||||
const I x1_shape_1 = x1_shape[1];
|
||||
|
||||
const size_t x1_shape_0_s = IntToSize(x1_shape_0);
|
||||
const size_t x1_shape_1_s = IntToSize(x1_shape_1);
|
||||
if (x1_shape_0_s <= kShapeValue || x1_shape_1_s <= kShapeValue) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of 'x1_shape' should be greater than 0.";
|
||||
}
|
||||
if (x1_shape_1_s != b_dim_0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the col of 'x1_shape' should be equal to the row of 'x2_dense',"
|
||||
" but got col: "
|
||||
<< x1_shape_1_s << ", row: " << b_dim_0;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < values_size_; ++i) {
|
||||
if (i * dim_num + 1 >= indices_length) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the index of 'indices' out of bounds.";
|
||||
}
|
||||
if (i >= values_length) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the index of 'values' out of bounds.";
|
||||
}
|
||||
|
||||
const int row = a_indices[i * dim_num];
|
||||
const int col = a_indices[i * dim_num + 1];
|
||||
if (row >= SizeToInt(out_dim_0) || row < 0 || col >= SizeToInt(same_dim) || col < 0) {
|
||||
MS_EXCEPTION(ValueError) << "The indices including out of bounds index, row range: [0, " << out_dim_0
|
||||
<< "), col range: [0, " << same_dim << "), but got row: " << row << ", col: " << col;
|
||||
}
|
||||
|
||||
const size_t row_s = IntToSize(row);
|
||||
const size_t col_s = IntToSize(col);
|
||||
for (size_t n = 0; n < out_dim_1; ++n) {
|
||||
if (col_s * b_dim_1 + n >= b_length) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the index of 'b' out of bounds.";
|
||||
}
|
||||
const T b_value = b[col_s * b_dim_1 + n];
|
||||
out[row_s * out_dim_1 + n] += *(alpha)*a_values[i] * b_value;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < out_dim_0; ++i) {
|
||||
for (size_t j = 0; j < out_dim_1; ++j) {
|
||||
const T c_value = c[i * out_dim_1 + j];
|
||||
out[i * out_dim_1 + j] += *(beta)*c_value;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, SparseAddmmCpuKernelMod::SparseAddmmFunc>> SparseAddmmCpuKernelMod::func_list_ = {
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, int8_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, int8_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddOutputAttr(kNumberTypeInt16),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, int16_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddOutputAttr(kNumberTypeInt16),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, int16_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, uint8_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, uint8_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddOutputAttr(kNumberTypeUInt16),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, uint16_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddOutputAttr(kNumberTypeUInt16),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, uint16_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt32),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, uint32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt32),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, uint32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt64),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int32_t, uint64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt64),
|
||||
&SparseAddmmCpuKernelMod::LaunchKernel<int64_t, uint64_t>}};
|
||||
|
||||
std::vector<KernelAttr> SparseAddmmCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, SparseAddmmFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SparseAddmm, SparseAddmmCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_ADDMM_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPARSE_ADDMM_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class SparseAddmmCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
public:
|
||||
SparseAddmmCpuKernelMod() = default;
|
||||
~SparseAddmmCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T, typename S>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
using SparseAddmmFunc = std::function<bool(SparseAddmmCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, SparseAddmmFunc>> func_list_;
|
||||
SparseAddmmFunc kernel_func_;
|
||||
|
||||
std::vector<size_t> output_shape_;
|
||||
std::vector<size_t> b_shape_;
|
||||
size_t output_size_{0};
|
||||
size_t values_size_{0};
|
||||
enum input_list_ { INDICES, VALUES, SPARSE_SHAPE, DENSE, DENSE_X2, ALPHA, BETA };
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RMSPROP_CPU_KERNEL_H_
|
|
@ -288,6 +288,7 @@ constexpr auto kCOOTensorDenseMatmul = "COOTensorDenseMatmul";
|
|||
|
||||
// Sparse ops
|
||||
constexpr auto kSparseTensorDenseMatmul = "SparseTensorDenseMatmul";
|
||||
constexpr auto kSparseAddmm = "SparseAddmm";
|
||||
constexpr auto kCSRReduceSum = "CSRReduceSum";
|
||||
constexpr auto kCSRMV = "CSRMV";
|
||||
constexpr auto kCSRMM = "CSRMM";
|
||||
|
@ -923,6 +924,7 @@ GVAR_DEF(PrimitivePtr, kPrimCSRTensorGetDenseShape, std::make_shared<Primitive>(
|
|||
|
||||
// Sparse ops
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseTensorDenseMatmul, std::make_shared<Primitive>(kSparseTensorDenseMatmul));
|
||||
GVAR_DEF(PrimitivePtr, kPrimSparseAddmm, std::make_shared<Primitive>(kSparseAddmm));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCOOTensorDenseMatmul, std::make_shared<Primitive>(kCOOTensorDenseMatmul));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCSRReduceSum, std::make_shared<Primitive>(kCSRReduceSum));
|
||||
GVAR_DEF(PrimitivePtr, kPrimCSRMV, std::make_shared<Primitive>(kCSRMV));
|
||||
|
|
|
@ -0,0 +1,124 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include "ops/sparse_addmm.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr SparseAddmmInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto indices_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto values_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
|
||||
auto shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
|
||||
auto x2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
|
||||
auto x3_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
|
||||
auto alpha_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape];
|
||||
auto beta_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->BuildShape())[kShape];
|
||||
const int kDimensionTwo = 2;
|
||||
const int kDimensionOne = 1;
|
||||
if (indices_shape.size() != kDimensionTwo) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the input indices should "
|
||||
<< "have rank 2, but got " << indices_shape.size() << ".";
|
||||
}
|
||||
if (indices_shape[1] != kDimensionTwo) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the 2nd dimension of indices "
|
||||
<< "should be 2, but got " << indices_shape[1] << ".";
|
||||
}
|
||||
if (values_shape.size() != kDimensionOne) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the input value should "
|
||||
<< "have rank 1, but got " << values_shape.size() << ".";
|
||||
}
|
||||
if (shape_shape.size() != kDimensionOne) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the input shape should "
|
||||
<< "have rank 1, but got " << shape_shape.size() << ".";
|
||||
}
|
||||
if (shape_shape[0] != kDimensionTwo) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the 1st dimension of input shape "
|
||||
<< "should be 2, but got " << shape_shape[0] << ".";
|
||||
}
|
||||
if (x2_shape.size() != kDimensionTwo) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the shape of input dense "
|
||||
<< "should be [2], but got [" << x2_shape.size() << "].";
|
||||
}
|
||||
if (x3_shape.size() != kDimensionTwo) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the shape of input dense "
|
||||
<< "should be [2], but got [" << x3_shape.size() << "].";
|
||||
}
|
||||
if (alpha_shape.size() != kDimensionOne) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the input shape should "
|
||||
<< "have rank 1, but got " << alpha_shape.size() << ".";
|
||||
}
|
||||
if (beta_shape.size() != kDimensionOne) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', the input shape should "
|
||||
<< "have rank 1, but got " << beta_shape.size() << ".";
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(x3_shape);
|
||||
}
|
||||
TypePtr SparseAddmmInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
std::map<std::string, TypePtr> types;
|
||||
std::set<TypePtr> valid_types = {kFloat32, kFloat64, kInt32, kInt64, kInt16,
|
||||
kInt8, kUInt32, kUInt64, kUInt16, kUInt8};
|
||||
TypePtr indices_type = input_args[kInputIndex0]->BuildType();
|
||||
TypePtr values_type = input_args[kInputIndex1]->BuildType();
|
||||
TypePtr shape_type = input_args[kInputIndex2]->BuildType();
|
||||
TypePtr x2_type = input_args[kInputIndex3]->BuildType();
|
||||
TypePtr x3_type = input_args[kInputIndex4]->BuildType();
|
||||
TypePtr alpha_type = input_args[kInputIndex5]->BuildType();
|
||||
TypePtr beta_type = input_args[kInputIndex6]->BuildType();
|
||||
auto prim_name = primitive->name();
|
||||
|
||||
(void)types.emplace("x1_values", values_type);
|
||||
(void)types.emplace("x2", x2_type);
|
||||
(void)types.emplace("x3", x3_type);
|
||||
(void)types.emplace("alpha", alpha_type);
|
||||
(void)types.emplace("beta", beta_type);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, primitive->name());
|
||||
|
||||
const std::set<TypePtr> valid_type = {kInt64, kInt32};
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("indices", indices_type, valid_type, prim_name);
|
||||
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("sparse_shape", shape_type, valid_type, prim_name);
|
||||
|
||||
auto tensor_type = x3_type->cast<TensorTypePtr>();
|
||||
auto tensor_element = tensor_type->element();
|
||||
|
||||
return tensor_element;
|
||||
}
|
||||
} // namespace
|
||||
AbstractBasePtr SparseAddmmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 7;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
|
||||
// infer type
|
||||
auto infer_type = SparseAddmmInferType(primitive, input_args);
|
||||
// infer shape
|
||||
auto infer_shape = SparseAddmmInferShape(primitive, input_args);
|
||||
return std::make_shared<abstract::AbstractTensor>(infer_type, infer_shape);
|
||||
}
|
||||
MIND_API_OPERATOR_IMPL(SparseAddmm, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(SparseAddmm, prim::kPrimSparseAddmm, SparseAddmmInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_SPARSE_ADDMM_H_
|
||||
#define MINDSPORE_CORE_OPS_SPARSE_ADDMM_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameSparseAddmm = "SparseAddmm";
|
||||
class MIND_API SparseAddmm : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(SparseAddmm);
|
||||
SparseAddmm() : BaseOperator(kNameSparseAddmm) {
|
||||
InitIOName({"indices", "values", "sparse_shape", "x2_dense", "x3_dense", "alpha", "beta"}, {"output"});
|
||||
}
|
||||
void Init() const {}
|
||||
};
|
||||
abstract::AbstractBasePtr SparseAddmmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
using PrimSparseAddmmPtr = std::shared_ptr<SparseAddmm>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif
|
|
@ -98,6 +98,7 @@ from .random_poisson import _random_poisson_aicpu
|
|||
from .random_choice_with_mask import _random_choice_with_mask_aicpu
|
||||
from .rsqrt import _rsqrt_aicpu
|
||||
from .rsqrt_grad import _rsqrt_grad_aicpu
|
||||
from .sparseaddmm import _sparse_addmm_aicpu
|
||||
from .search_sorted import _search_sorted_aicpu
|
||||
from .stack import _stack_aicpu
|
||||
from .unstack import _unstack_aicpu
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""SparseAddmm op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
sparseaddmm_op_info = AiCPURegOp("SparseAddmm") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "x1_indices", "required") \
|
||||
.input(1, "x1_values", "required") \
|
||||
.input(2, "x1_shape", "required") \
|
||||
.input(3, "x2", "required") \
|
||||
.input(4, "x3", "required") \
|
||||
.input(5, "alpha", "required") \
|
||||
.input(6, "beta", "required") \
|
||||
.output(0, "y", "required") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I8_Default, DataType.I32_Default, DataType.I8_Default,
|
||||
DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I8_Default, DataType.I64_Default, DataType.I8_Default,
|
||||
DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I16_Default, DataType.I32_Default, DataType.I16_Default,
|
||||
DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I16_Default, DataType.I64_Default, DataType.I16_Default,
|
||||
DataType.I16_Default, DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
|
||||
DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default, DataType.I32_Default,
|
||||
DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I64_Default,
|
||||
DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default,
|
||||
DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.U8_Default, DataType.I32_Default, DataType.U8_Default,
|
||||
DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.U8_Default, DataType.I64_Default, DataType.U8_Default,
|
||||
DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.U16_Default, DataType.I32_Default, DataType.U16_Default,
|
||||
DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.U16_Default, DataType.I64_Default, DataType.U16_Default,
|
||||
DataType.U16_Default, DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.U32_Default, DataType.I32_Default, DataType.U32_Default,
|
||||
DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.U32_Default, DataType.I64_Default, DataType.U32_Default,
|
||||
DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.U64_Default, DataType.I32_Default, DataType.U64_Default,
|
||||
DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.I64_Default, DataType.U64_Default,
|
||||
DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.I32_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.I64_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.I64_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.I32_Default, DataType.F64_Default,
|
||||
DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.I64_Default, DataType.F64_Default,
|
||||
DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.C64_Default, DataType.I32_Default, DataType.C64_Default,
|
||||
DataType.C64_Default, DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.C64_Default, DataType.I64_Default, DataType.C64_Default,
|
||||
DataType.C64_Default, DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.C128_Default, DataType.I32_Default, DataType.C128_Default,
|
||||
DataType.C128_Default, DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.C128_Default, DataType.I64_Default, DataType.C128_Default,
|
||||
DataType.C128_Default, DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(sparseaddmm_op_info)
|
||||
def _sparse_addmm_aicpu():
|
||||
"""SparseAddmm AiCPU register"""
|
||||
return
|
|
@ -1,6 +1,6 @@
|
|||
# coding: utf-8
|
||||
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -531,6 +531,58 @@ class Sspaddmm(Primitive):
|
|||
'x3_dense', 'alpha', 'beta'], outputs=['y_indices', 'y_values', 'y_shape'])
|
||||
|
||||
|
||||
class SparseAddmm(Primitive):
|
||||
"""
|
||||
Multiplies sparse matrix `A` by dense matrix `B` * `alpha` and add dense matrix `C` * `beta`.
|
||||
The rank of sparse matrix and dense matrix must equal to `2`.
|
||||
|
||||
Inputs:
|
||||
- **indices** (Tensor) - A 2-D Tensor, represents the position of the element in the sparse tensor.
|
||||
Support int32, int64, each element value should be a non-negative int number. The shape is :math:`(n, 2)`.
|
||||
- **values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in the `indices`.
|
||||
Support float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64.
|
||||
The shape should be :math:`(n,)`.
|
||||
- **sparse_shape** (Tensor) - A positive int tuple which specifies the shape of sparse tensor.
|
||||
Support int32, int64, should have 2 elements, represent sparse tensor shape is :math:`(N, C)`.
|
||||
- **x2_dense** (Tensor) - A 2-D Tensor, the dtype is same as `values`.
|
||||
- **x3_dense** (Tensor) - A 2-D Tensor, the dtype is same as `values`.
|
||||
- **alpha** (Tensor) - A 1-D Tensor, the dtype is same as `values`.
|
||||
- **beta** (Tensor) - A 1-D Tensor, the dtype is same as `values`.
|
||||
|
||||
Outputs:
|
||||
Tensor, the dtype is the same as `values`.
|
||||
|
||||
Raises:
|
||||
TypeError: If dtype of `indices`, dtype of `values` and dtype of `dense` don't meet the parameter description.
|
||||
ValueError: If `sparse_shape`, shape of `indices, shape of `values`, and shape of `dense` don't meet the
|
||||
parameter description.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> indices = Tensor([[0, 1], [1, 2]], dtype=ms.int32)
|
||||
>>> values = Tensor([1, 2], dtype=ms.float32)
|
||||
>>> sparse_shape = Tensor([1, 2], dtype=ms.int32)
|
||||
>>> x2_dense = Tensor([[1,1], [2,2], [3,3], [4,4]], dtype=ms.float32)
|
||||
>>> x3_dense = Tensor([[2,2], [6,6], [0,0]], dtype=ms.float32)
|
||||
>>> alpha = Tensor([1], dtype=ms.float32)
|
||||
>>> beta = Tensor([1], dtype=ms.float32)
|
||||
>>> sparse_addmm = ops.SparseAddmm()
|
||||
>>> out = sparse_addmm(indices, values, sparse_shape, x2_dense, x3_dense, alpha, beta)
|
||||
>>> print(out)
|
||||
[[4 4]
|
||||
[12 12]
|
||||
[0 0]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize SparseAddmm"""
|
||||
self.init_prim_io_names(inputs=['indices', 'values', 'sparse_shape', 'x2_dense', 'x3_dense', 'alpha', 'beta'],
|
||||
outputs=['output'])
|
||||
|
||||
|
||||
class SparseConcat(Primitive):
|
||||
"""
|
||||
concatenates the input SparseTensor(COO format) along the specified dimension. demo API now
|
||||
|
|
|
@ -224,6 +224,9 @@ class InputOpNet(nn.Cell):
|
|||
x = self.op(x1, x2, x3, x4, x5, self.c1, self.c2, self.c3, self.c4)
|
||||
return x
|
||||
|
||||
def construct7_c0(self, x1, x2, x3, x4, x5, x6, x7):
|
||||
x = self.op(x1, x2, x3, x4, x5, x6, x7)
|
||||
return x
|
||||
|
||||
def construct9_c0(self, x1, x2, x3, x4, x5, x6, x7, x8, x9):
|
||||
x = self.op(x1, x2, x3, x4, x5, x6, x7, x8, x9)
|
||||
|
|
|
@ -125,6 +125,7 @@ from mindspore.ops.operations.sparse_ops import SparseMatrixNNZ
|
|||
from mindspore.ops.operations.sparse_ops import SparseTensorDenseAdd
|
||||
from mindspore.ops.operations.sparse_ops import SparseMatrixTranspose
|
||||
from mindspore.ops.operations.sparse_ops import CSRSparseMatrixToSparseTensor
|
||||
from mindspore.ops.operations.sparse_ops import SparseAddmm
|
||||
from mindspore.ops.operations.sparse_ops import SparseTensorToCSRSparseMatrix
|
||||
from mindspore.ops.operations.sparse_ops import SparseSparseMinimum
|
||||
from mindspore.ops.operations.sparse_ops import SparseSegmentSqrtN
|
||||
|
@ -3227,6 +3228,16 @@ test_case_array_ops = [
|
|||
'block': UnravelIndex(),
|
||||
'desc_inputs': [Tensor(np.array([5, 5]).astype(np.int64)), Tensor(np.array([3, 3]).astype(np.int64))],
|
||||
'skip': ['backward']}),
|
||||
('SparseAddmm', {
|
||||
'block': SparseAddmm(),
|
||||
'desc_inputs': [Tensor(np.array([[0, 1], [1, 2]]).astype(np.int32)),
|
||||
Tensor(np.array([1, 2]).astype(np.int32)),
|
||||
Tensor(np.array([3, 4]).astype(np.int32)),
|
||||
Tensor(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype(np.int32)),
|
||||
Tensor(np.array([[2, 2], [6, 6], [0, 0]]).astype(np.int32)),
|
||||
Tensor(np.array([1]).astype(np.int32)),
|
||||
Tensor(np.array([1]).astype(np.int32))],
|
||||
'skip': ['backward']}),
|
||||
('SpaceToDepth', {
|
||||
'block': P.SpaceToDepth(2),
|
||||
'desc_inputs': [[1, 3, 2, 2]],
|
||||
|
|
Loading…
Reference in New Issue