From 16aad8be1b4ef53b790ef9237611a5483fe1308b Mon Sep 17 00:00:00 2001 From: Lan-Ling Date: Tue, 14 Jun 2022 12:52:25 +0800 Subject: [PATCH] update aicpu operator sspaddmm --- .../device/cpu/kernel/sspaddmm_cpu_kernel.cc | 371 ++++++++++++++ .../device/cpu/kernel/sspaddmm_cpu_kernel.h | 72 +++ .../core/abstract/ops/primitive_infer_map.cc | 4 +- mindspore/core/ops/core_ops.h | 1 + mindspore/core/ops/sspaddmm.cc | 484 ++++++++++++++++++ mindspore/core/ops/sspaddmm.h | 49 ++ .../mindspore/ops/_op_impl/aicpu/__init__.py | 1 + .../mindspore/ops/_op_impl/aicpu/sspaddmm.py | 97 ++++ .../mindspore/ops/_op_impl/cpu/__init__.py | 1 + .../mindspore/ops/_op_impl/cpu/sspaddmm.py | 95 ++++ .../mindspore/ops/operations/sparse_ops.py | 100 ++++ .../utils/block_util.py | 5 + tests/ut/python/ops/test_ops.py | 14 +- 13 files changed, 1292 insertions(+), 2 deletions(-) create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/sspaddmm_cpu_kernel.cc create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/sspaddmm_cpu_kernel.h create mode 100644 mindspore/core/ops/sspaddmm.cc create mode 100644 mindspore/core/ops/sspaddmm.h create mode 100644 mindspore/python/mindspore/ops/_op_impl/aicpu/sspaddmm.py create mode 100644 mindspore/python/mindspore/ops/_op_impl/cpu/sspaddmm.py diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sspaddmm_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/sspaddmm_cpu_kernel.cc new file mode 100644 index 00000000000..600942deb37 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sspaddmm_cpu_kernel.cc @@ -0,0 +1,371 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "plugin/device/cpu/kernel/sspaddmm_cpu_kernel.h"
+#include <complex>
+#include <string>
+#include <unordered_map>
+#include "plugin/device/cpu/hal/device/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+constexpr size_t kInputsNum = 9;
+constexpr size_t kOutputsNum = 3;
+constexpr char kKernelName[] = "Sspaddmm";
+
+#define CHECKSPARSEINDICES(dtype1, dtype2, indices, shapes, num, x_name)      \
+  if (dtype1 == kNumberTypeInt32) {                                           \
+    CheckSparseIndicesLegal<int32_t, int32_t>(indices, shapes, num, x_name);  \
+  } else {                                                                    \
+    CheckSparseIndicesLegal<int64_t, int64_t>(indices, shapes, num, x_name);  \
+  }
+}  // namespace
+
+void SspaddmmCPUKernelMod::CheckParam(const CNodePtr &kernel_node) {
+  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
+  CHECK_KERNEL_INPUTS_NUM(input_num, kInputsNum, kKernelName);
+  size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
+  CHECK_KERNEL_OUTPUTS_NUM(output_num, kOutputsNum, kKernelName);
+}
+
+void SspaddmmCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  CheckParam(kernel_node);
+
+  output_values_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex1);
+  input_indices_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex0);
+  input_shape_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex2);
+  mat1_indices_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex3);
+  mat1_shape_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex5);
+  alpha_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex7);
+  beta_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex8);
+
+  auto input_indices_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex0);
+  auto mat1_indices_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex3);
+  auto y_indices_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0);
+  auto mat2_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex6);
+
+  input_values_num_ = input_indices_shape[1];
+  mat1_values_num_ = mat1_indices_shape[1];
+  y_values_num_ = y_indices_shape[1];
+  mat2_row_ = mat2_shape[0];
+  mat2_col_ = mat2_shape[1];
+}
+
+bool SspaddmmCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                                  const std::vector<AddressPtr> &outputs) {
+  switch (output_values_dtype_) {
+    case kNumberTypeUInt8: {
+      LaunchKernel<uint8_t>(inputs, outputs);
+      break;
+    }
+    case kNumberTypeInt8: {
+      LaunchKernel<int8_t>(inputs, outputs);
+      break;
+    }
+    case kNumberTypeInt16: {
+      LaunchKernel<int16_t>(inputs, outputs);
+      break;
+    }
+    case kNumberTypeInt32: {
+      LaunchKernel<int32_t>(inputs, outputs);
+      break;
+    }
+    case kNumberTypeInt64: {
+      LaunchKernel<int64_t>(inputs, outputs);
+      break;
+    }
+    case kNumberTypeFloat32: {
+      LaunchKernel<float>(inputs, outputs);
+      break;
+    }
+    case kNumberTypeFloat64: {
+      LaunchKernel<double>(inputs, outputs);
+      break;
+    }
+    default: {
+      MS_EXCEPTION(TypeError) << "For Sspaddmm, the output values dtype is not supported: expected uint8, int8, "
+                              << "int16, int32, int64, float32 or float64.";
+    }
+  }
+  return true;
+}
+
+template <typename T>
+void SspaddmmCPUKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
+                                        const std::vector<AddressPtr> &outputs) {
+  auto input_indices_addr = inputs[0]->addr;
+  auto input_values_addr = reinterpret_cast<T *>(inputs[1]->addr);
+  auto input_shape_addr = inputs[2]->addr;
+  auto mat1_indices_addr = inputs[3]->addr;
+  auto mat1_values_addr = reinterpret_cast<T *>(inputs[4]->addr);
+  auto mat1_shape_addr = inputs[5]->addr;
+  auto mat2_addr = reinterpret_cast<T *>(inputs[6]->addr);
+  auto alpha_val_addr = inputs[7]->addr;
+  auto beta_val_addr = inputs[8]->addr;
+  auto y_indices_addr = reinterpret_cast<int64_t *>(outputs[0]->addr);
+  auto y_values_addr = reinterpret_cast<T *>(outputs[1]->addr);
+  auto y_shape_addr = reinterpret_cast<int64_t *>(outputs[2]->addr);
+  CHECKSPARSEINDICES(input_indices_dtype_, input_shape_dtype_, input_indices_addr, input_shape_addr,
+                     input_values_num_, "x1");
+  CHECKSPARSEINDICES(mat1_indices_dtype_, mat1_shape_dtype_, mat1_indices_addr, mat1_shape_addr, mat1_values_num_,
+                     "x2");
+
+  int64_t mat1_row, mat1_col, input_row, input_col;
+  if (mat1_shape_dtype_ == kNumberTypeInt32) {
+    auto mat1_shape_val = reinterpret_cast<int32_t *>(mat1_shape_addr);
+    mat1_row = static_cast<int64_t>(mat1_shape_val[0]);
+    mat1_col = static_cast<int64_t>(mat1_shape_val[1]);
+  } else {
+    auto mat1_shape_val = reinterpret_cast<int64_t *>(mat1_shape_addr);
+    mat1_row = mat1_shape_val[0];
+    mat1_col = mat1_shape_val[1];
+  }
+  if (input_shape_dtype_ == kNumberTypeInt32) {
+    auto input_shape_val = reinterpret_cast<int32_t *>(input_shape_addr);
+    input_row = static_cast<int64_t>(input_shape_val[0]);
+    input_col = static_cast<int64_t>(input_shape_val[1]);
+  } else {
+    auto input_shape_val = reinterpret_cast<int64_t *>(input_shape_addr);
+    input_row = input_shape_val[0];
+    input_col = input_shape_val[1];
+  }
+
+  if (mat1_col != static_cast<int64_t>(mat2_row_)) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm, the col of the sparse tensor x2's dense shape: " << mat1_col
+                             << " should be equal to the row of x3_dense's shape: " << mat2_row_ << ".";
+  }
+  if (mat1_row != input_row || static_cast<int64_t>(mat2_col_) != input_col) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm, the sparse x1's dense shape "
+                             << "[" << input_row << ", " << input_col << "] should be equal to the x2@x3_dense "
+                             << "shape [" << mat1_row << ", " << mat2_col_ << "].";
+  }
+
+  if (input_shape_dtype_ == kNumberTypeInt32) {
+    InitShape<int32_t>(input_shape_addr, y_shape_addr);
+  } else {
+    InitShape<int64_t>(input_shape_addr, y_shape_addr);
+  }
+
+  ClearSparseValues(y_values_addr, y_values_num_);
+
+  // scalar * sparse, in place: beta * input and alpha * mat1
+  T *input_values_addr_bak = ScalarSparseMul(input_values_addr, beta_val_addr, input_values_num_, beta_dtype_);
+  T *mat1_values_addr_bak = ScalarSparseMul(mat1_values_addr, alpha_val_addr, mat1_values_num_, alpha_dtype_);
+
+  // sparse + sparse
+  if (input_indices_dtype_ == kNumberTypeInt32) {
+    SparseAddSparse<int32_t>(input_indices_addr, input_values_addr_bak, input_values_num_, y_indices_addr,
+                             y_values_addr, y_values_num_);
+  } else {
+    SparseAddSparse<int64_t>(input_indices_addr, input_values_addr_bak, input_values_num_, y_indices_addr,
+                             y_values_addr, y_values_num_);
+  }
+  auto row = mat1_col;
+  auto col = input_col;
+  if (mat1_indices_dtype_ == kNumberTypeInt32) {
+    SparseMulDense<int32_t>(mat1_indices_addr, mat1_values_addr_bak, mat1_values_num_, mat2_addr, y_indices_addr,
+                            y_values_addr, y_values_num_, row, col);
+  } else {
+    SparseMulDense<int64_t>(mat1_indices_addr, mat1_values_addr_bak, mat1_values_num_, mat2_addr, y_indices_addr,
+                            y_values_addr, y_values_num_, row, col);
+  }
+}
+
+template <typename T, typename S>
+void SspaddmmCPUKernelMod::CheckSparseIndicesLegal(void *indices_addr, void *shape_addr, size_t num,
+                                                   std::string x_name) {
+  auto indices_val = reinterpret_cast<T *>(indices_addr);
+  auto shape_val = reinterpret_cast<S *>(shape_addr);
+  int shape_num = 2;
+  for (int i = 0; i < shape_num; i++) {
+    if (shape_val[i] <= 0) {
+      MS_EXCEPTION(ValueError) << "For Sspaddmm, the " << x_name << "_shape values should be positive,"
+                               << " while got shape [" << shape_val[0] << ", " << shape_val[1] << "].";
+    }
+  }
+  for (size_t i = 0; i < num; i++) {
+    int64_t row = static_cast<int64_t>(shape_val[0]);
+    int64_t col = static_cast<int64_t>(shape_val[1]);
+    int64_t indices_row = static_cast<int64_t>(indices_val[i]);
+    int64_t indices_col = static_cast<int64_t>(indices_val[i + num]);
+    if ((indices_row >= row) || (indices_col >= col) || (indices_row < 0) || (indices_col < 0)) {
+      MS_EXCEPTION(ValueError) << "For Sspaddmm, the " << x_name << "_indices"
+                               << " row value: " << indices_row << ", col value: " << indices_col
+                               << " is out of bounds. The row index should be in [0, " << row
+                               << ") and the col index should be in [0, " << col << ").";
+    }
+  }
+}
+
+template <typename T>
+void SspaddmmCPUKernelMod::InitShape(void *input_shape, int64_t *y_shape) {
+  auto input_shape_val = reinterpret_cast<T *>(input_shape);
+  size_t shape_num = 2;
+  for (size_t i = 0; i < shape_num; i++) {
+    y_shape[i] = static_cast<int64_t>(input_shape_val[i]);
+  }
+}
+
+template <typename T>
+void SspaddmmCPUKernelMod::ClearSparseValues(T *sparse_val, size_t data_num) {
+  auto task = [&](size_t start, size_t end) {
+    for (size_t i = start; i < end; i++) {
+      sparse_val[i] = static_cast<T>(0);
+    }
+  };
+  ParallelLaunchAutoSearch(task, data_num, this, &parallel_search_info_);
+}
+
+// scalar * sparse matrix, for beta * input and alpha * mat1
+template <typename T>
+T *SspaddmmCPUKernelMod::ScalarSparseMul(T *sparse_val, void *scalar_val, size_t data_num, TypeId tid) {
+  T val;
+  if (data_num < 1) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm, data_num should be greater than 0.";
+  }
+  T *sparse_val_bak = new T[data_num];
+  switch (tid) {
+    case kNumberTypeUInt8:
+      val = static_cast<T>(reinterpret_cast<uint8_t *>(scalar_val)[0]);
+      break;
+    case kNumberTypeUInt16:
+      val = static_cast<T>(reinterpret_cast<uint16_t *>(scalar_val)[0]);
+      break;
+    case kNumberTypeUInt32:
+      val = static_cast<T>(reinterpret_cast<uint32_t *>(scalar_val)[0]);
+      break;
+    case kNumberTypeUInt64:
+      val = static_cast<T>(reinterpret_cast<uint64_t *>(scalar_val)[0]);
+      break;
+    case kNumberTypeInt8:
+      val = static_cast<T>(reinterpret_cast<int8_t *>(scalar_val)[0]);
+      break;
+    case kNumberTypeInt16:
+      val = static_cast<T>(reinterpret_cast<int16_t *>(scalar_val)[0]);
+      break;
+    case kNumberTypeInt32:
+      val = static_cast<T>(reinterpret_cast<int32_t *>(scalar_val)[0]);
+      break;
+    case kNumberTypeInt64:
+      val = static_cast<T>(reinterpret_cast<int64_t *>(scalar_val)[0]);
+      break;
+    case kNumberTypeFloat16:
+      val = static_cast<T>(static_cast<float>(reinterpret_cast<float16 *>(scalar_val)[0]));
+      break;
+    case kNumberTypeFloat32:
+      val = static_cast<T>(reinterpret_cast<float *>(scalar_val)[0]);
+      break;
+    case kNumberTypeFloat64:
+      val = static_cast<T>(reinterpret_cast<double *>(scalar_val)[0]);
+      break;
+    case kNumberTypeBool:
+      val = static_cast<T>(reinterpret_cast<bool *>(scalar_val)[0]);
+      break;
+    case kNumberTypeComplex64:
+      val = static_cast<T>(reinterpret_cast<std::complex<float> *>(scalar_val)[0].real());
+      break;
+    case kNumberTypeComplex128:
+      val = static_cast<T>(reinterpret_cast<std::complex<double> *>(scalar_val)[0].real());
+      break;
+    default:
+      MS_EXCEPTION(TypeError) << "For Sspaddmm, the dtype of alpha or beta is not supported.";
+      break;
+  }
+  auto task = [&](size_t start, size_t end) {
+    for (size_t i = start; i < end; i++) {
+      sparse_val_bak[i] = sparse_val[i] * val;
+    }
+  };
+  ParallelLaunchAutoSearch(task, data_num, this, &parallel_search_info_);
+  return sparse_val_bak;
+}
+
+// sparse matrix add sparse matrix
+// input + mat1 @ mat2
+template <typename T, typename S>
+void SspaddmmCPUKernelMod::SparseAddSparse(void *input_indices, S *input_values, size_t input_num, int64_t *y_indices,
+                                           S *y_values, size_t y_num) {
+  // to implement m1[row][col] = vals
+  auto input_ids = reinterpret_cast<T *>(input_indices);
+  this->cnt_ = input_num;
+  // get output vals and index addr
+  auto task = [&](size_t start, size_t end) {
+    for (size_t i = start; i < end; i++) {
+      auto row = input_ids[i];
+      auto col = input_ids[i + input_num];
+      y_values[i] = input_values[i];
+      y_indices[i] = static_cast<int64_t>(row);
+      y_indices[i + y_num] = static_cast<int64_t>(col);
+    }
+  };
+  ParallelLaunchAutoSearch(task, input_num, this, &parallel_search_info_);
+}
+
+template <typename T, typename S>
+void SspaddmmCPUKernelMod::SparseMulDense(void *mat1_indices, S *mat1_values, size_t mat1_vals_num, S *mat2_addr,
+                                          int64_t *y_indices, S *y_values, size_t y_vals_num, int64_t row,
+                                          int64_t mat2_col_) {
+  // the result of mat1 @ mat2 is written to the output directly
+  auto mat1_ids = reinterpret_cast<T *>(mat1_indices);
+
+  std::unordered_map<T, std::unordered_map<int64_t, T>> idx_map_cnt;
+  std::unordered_map<T, std::vector<T>> unrepeated;
+  std::unordered_map<T, std::unordered_map<T, std::vector<S>>> co_map_idx;
+
+  // unrepeated : [1 -> [0], 2 -> [1, 2]]
+  // co_map_idx : [1][0] -> 0.3
+  for (size_t i = 0; i < mat1_vals_num; i++) {
+    T _row = mat1_ids[i];
+    T _col = mat1_ids[i + mat1_vals_num];
+    unrepeated[_row].push_back(_col);
+    co_map_idx[_row][_col].push_back(mat1_values[i]);
+    for (int64_t j = 0; j < mat2_col_; j++) {
+      if (idx_map_cnt[_row][j] == 0) {
+        idx_map_cnt[_row][j] = static_cast<T>(this->cnt_);
+        this->cnt_++;
+      }
+    }
+  }
+
+  std::vector<T> res;
+  for (auto it = unrepeated.begin(); it != unrepeated.end(); it++) {
+    res.push_back(it->first);
+  }
+
+  size_t n_unrepeat = unrepeated.size();
+  auto task = [&](size_t start, size_t end) {
+    for (size_t i = start; i < end; i++) {
+      // get val
+      auto row_mat1 = res[i];
+      for (auto row_mat2 : unrepeated[row_mat1]) {
+        S val = co_map_idx[row_mat1][row_mat2].back();
+        co_map_idx[row_mat1][row_mat2].pop_back();
+        for (int64_t j = 0; j < mat2_col_; j++) {
+          // y[row_mat1][j] += val * mat2[row_mat2][j]
+          T idx = idx_map_cnt[row_mat1][j];
+          *(y_values + idx) += val * mat2_addr[row_mat2 * mat2_col_ + j];
+          y_indices[idx] = static_cast<int64_t>(row_mat1);
+          y_indices[idx + y_vals_num] = j;
+        }
+      }
+    }
+  };
+  ParallelLaunchAutoSearch(task, n_unrepeat, this, &parallel_search_info_);
+}
+
+MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Sspaddmm, SspaddmmCPUKernelMod);
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/sspaddmm_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/sspaddmm_cpu_kernel.h
new file mode 100644
index 00000000000..665815c8b53
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/sspaddmm_cpu_kernel.h
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SSPADDMM_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SSPADDMM_CPU_KERNEL_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "plugin/device/cpu/kernel/cpu_kernel.h"
+#include "plugin/factory/ms_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class SspaddmmCPUKernelMod : public DeprecatedNativeCpuKernelMod {
+ public:
+  SspaddmmCPUKernelMod() = default;
+  ~SspaddmmCPUKernelMod() override = default;
+  void InitKernel(const CNodePtr &kernel_node) override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+ private:
+  template <typename T>
+  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+
+  void CheckParam(const CNodePtr &kernel_node);
+  template <typename T, typename S>
+  void CheckSparseIndicesLegal(void *indices_addr, void *shape_addr, size_t num, std::string x_name);
+  template <typename T>
+  void InitShape(void *input_shape, int64_t *y_shape);
+  template <typename T>
+  void ClearSparseValues(T *sparse_val, size_t data_num);
+  template <typename T>
+  T *ScalarSparseMul(T *sparse_val, void *scalar_val, size_t data_num, TypeId tid);
+  template <typename T, typename S>
+  void SparseAddSparse(void *input_indices, S *input_values, size_t input_num, int64_t *y_indices, S *y_values,
+                       size_t y_num);
+  template <typename T, typename S>
+  void SparseMulDense(void *mat1_indices, S *mat1_values, size_t mat1_vals_num, S *mat2_values_tensor,
+                      int64_t *y_indices, S *y_values, size_t y_vals_num, int64_t row, int64_t mat2_col);
+
+  TypeId output_values_dtype_{kTypeUnknown};
+  TypeId input_indices_dtype_{kTypeUnknown};
+  TypeId input_shape_dtype_{kTypeUnknown};
+  TypeId mat1_indices_dtype_{kTypeUnknown};
+  TypeId mat1_shape_dtype_{kTypeUnknown};
+  TypeId alpha_dtype_{kTypeUnknown};
+  TypeId beta_dtype_{kTypeUnknown};
+  size_t input_values_num_{0};
+  size_t mat1_values_num_{0};
+  size_t y_values_num_{0};
+  size_t cnt_{0};
+  size_t mat2_row_{0};
+  size_t mat2_col_{0};
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SSPADDMM_CPU_KERNEL_H_
diff --git a/mindspore/core/abstract/ops/primitive_infer_map.cc b/mindspore/core/abstract/ops/primitive_infer_map.cc
index a52d691ac70..0597a6a7020 100644
--- a/mindspore/core/abstract/ops/primitive_infer_map.cc
+++ b/mindspore/core/abstract/ops/primitive_infer_map.cc
@@ -91,6 +91,7 @@ PrimShapeDependMap &GetHostDependsMap() {
   static const auto &kSegmentSum = prim::kPrimSegmentSum->name();
   static const auto &kBlackmanWindow = prim::kPrimBlackmanWindow->name();
   static const auto &kExpand = prim::kPrimExpand->name();
+  static const auto &kSspaddmm = prim::kPrimSspaddmm->name();
   // Common host depends.
   static PrimShapeDependMap host_depends{{kSegmentMax, ShapeSet{1}},
                                          {kSegmentMin, ShapeSet{1}},
@@ -132,7 +133,8 @@ PrimShapeDependMap &GetHostDependsMap() {
     {kSparseSegmentMean, ShapeSet{2}},
     {kResizeLinear1D, ShapeSet{1}},
     {kBlackmanWindow, ShapeSet{0}},
-    {kExpand, ShapeSet{1}}};
+    {kExpand, ShapeSet{1}},
+    {kSspaddmm, ShapeSet{0, 2, 3, 5, 7}}};
   return host_depends;
 }
diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h
index 8653b3d8114..82b0814c2ee 100644
--- a/mindspore/core/ops/core_ops.h
+++ b/mindspore/core/ops/core_ops.h
@@ -340,6 +340,7 @@ GVAR_DEF(PrimitivePtr, kPrimGatherNd, std::make_shared<Primitive>("GatherNd"));
 GVAR_DEF(PrimitivePtr, kPrimSparseGatherV2, std::make_shared<Primitive>("SparseGatherV2"));
 GVAR_DEF(PrimitivePtr, kPrimCoalesce, std::make_shared<Primitive>(kCoalesce));
 GVAR_DEF(PrimitivePtr, kPrimSparseToDense, std::make_shared<Primitive>("SparseToDense"));
+GVAR_DEF(PrimitivePtr, kPrimSspaddmm, std::make_shared<Primitive>("Sspaddmm"));
 GVAR_DEF(PrimitivePtr, kPrimShape, std::make_shared<Primitive>("Shape"));
 GVAR_DEF(PrimitivePtr, kPrimStridedSlice, std::make_shared<Primitive>(kStridedSlice));
 GVAR_DEF(PrimitivePtr, kPrimStridedSliceGrad, std::make_shared<Primitive>(kStridedSliceGrad));
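A note on the new `{kSspaddmm, ShapeSet{0, 2, 3, 5, 7}}` entry: shape inference needs the host values of `x1_indices`, `x1_shape`, `x2_indices`, `x2_shape` and `alpha`, because the output nnz depends on the index data rather than on input shapes alone. A hedged sketch of the formula `SspaddmmInferShape` evaluates (assuming NumPy-style inputs; the function name is illustrative):

```python
def infer_output_nnz(x1_indices, x2_indices, x3_dense):
    # One dense output row per unique x2 row index, plus every x1 entry.
    unique_x2_rows = len(set(x2_indices[0]))  # compute_output_indices_unique_size_*
    return unique_x2_rows * x3_dense.shape[1] + x1_indices.shape[1]
```

For the docstring example at the end of this patch this gives 2 unique rows * 3 cols + 2 entries = 8 stored elements, which matches the printed output.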
diff --git a/mindspore/core/ops/sspaddmm.cc b/mindspore/core/ops/sspaddmm.cc
new file mode 100644
index 00000000000..b3a897acd69
--- /dev/null
+++ b/mindspore/core/ops/sspaddmm.cc
@@ -0,0 +1,484 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <climits>
+#include <complex>
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+#include "ops/sspaddmm.h"
+#include "ops/op_utils.h"
+#include "utils/check_convert_utils.h"
+#include "abstract/ops/primitive_infer_map.h"
+
+namespace mindspore {
+namespace ops {
+namespace {
+const int64_t MAX_LEN = 1000000;
+
+int64_t compute_output_indices_unique_size_int32(int32_t *indices, size_t size) {
+  std::set<int32_t> mat1_indices_set;
+  size_t half_size = size / 2;
+  for (size_t i = 0; i < half_size; i++) {
+    mat1_indices_set.insert(indices[i]);
+  }
+  return static_cast<int64_t>(mat1_indices_set.size());
+}
+
+int64_t compute_output_indices_unique_size_int64(int64_t *indices, size_t size) {
+  std::set<int64_t> mat1_indices_set;
+  size_t half_size = size / 2;
+  for (size_t i = 0; i < half_size; i++) {
+    mat1_indices_set.insert(indices[i]);
+  }
+  return static_cast<int64_t>(mat1_indices_set.size());
+}
+
+enum DimNum : size_t {
+  dim0Num = 0,
+  dim1Num,
+  dim2Num,
+};
+
+int64_t GetInt64AlphaDataOther(void *values, TypeId tid, TypePtr expect_dtype, float real) {
+  int64_t compute_val = 0;
+  bool flag1 = expect_dtype->type_id() == kNumberTypeUInt8;
+  bool flag2 = false;
+  switch (tid) {
+    case kNumberTypeFloat16:
+      flag2 = flag1 && (static_cast<float>(reinterpret_cast<float16 *>(values)[0]) < 0);
+      compute_val = static_cast<int64_t>(static_cast<float>(reinterpret_cast<float16 *>(values)[0]));
+      break;
+    case kNumberTypeFloat32:
+      flag2 = flag1 && (reinterpret_cast<float *>(values)[0] < 0);
+      compute_val = static_cast<int64_t>(reinterpret_cast<float *>(values)[0]);
+      break;
+    case kNumberTypeFloat64:
+      flag2 = flag1 && (reinterpret_cast<double *>(values)[0] < 0);
+      compute_val = static_cast<int64_t>(reinterpret_cast<double *>(values)[0]);
+      break;
+    case kNumberTypeBool:
+      compute_val = static_cast<int64_t>(reinterpret_cast<bool *>(values)[0]);
+      break;
+    case kNumberTypeComplex64:
+    case kNumberTypeComplex128:
+      compute_val = static_cast<int64_t>(real);
+      break;
+    default:
+      MS_EXCEPTION(TypeError) << "For Sspaddmm, the dtype of alpha is not supported: only number types, bool, "
+                              << "complex64 and complex128 are supported.";
+      break;
+  }
+  if (flag2) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm, the alpha value cannot be converted to type uint8 without overflow.";
+  }
+  return compute_val;
+}
+
+int64_t GetInt64AlphaData(void *values, TypeId tid, TypePtr expect_dtype, float real) {
+  int64_t compute_val = 0;
+  switch (tid) {
+    case kNumberTypeUInt8:
+      compute_val = static_cast<int64_t>(reinterpret_cast<uint8_t *>(values)[0]);
+      break;
+    case kNumberTypeUInt16:
+      compute_val = static_cast<int64_t>(reinterpret_cast<uint16_t *>(values)[0]);
+      break;
+    case kNumberTypeUInt32:
+      compute_val = static_cast<int64_t>(reinterpret_cast<uint32_t *>(values)[0]);
+      break;
+    case kNumberTypeUInt64:
+      compute_val = static_cast<int64_t>(reinterpret_cast<uint64_t *>(values)[0]);
+      break;
+    case kNumberTypeInt8:
+      compute_val = static_cast<int64_t>(reinterpret_cast<int8_t *>(values)[0]);
+      break;
+    case kNumberTypeInt16:
+      compute_val = static_cast<int64_t>(reinterpret_cast<int16_t *>(values)[0]);
+      break;
+    case kNumberTypeInt32:
+      compute_val = static_cast<int64_t>(reinterpret_cast<int32_t *>(values)[0]);
+      break;
+    case kNumberTypeInt64:
+      compute_val = static_cast<int64_t>(reinterpret_cast<int64_t *>(values)[0]);
+      break;
+    default:
+      compute_val = GetInt64AlphaDataOther(values, tid, expect_dtype, real);
+      break;
+  }
+  return compute_val;
+}
+
+void CheckAlphaBeta(const std::vector<AbstractBasePtr> &input_args) {
+  auto alpha_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex7]->BuildShape())[kShape];
+  auto beta_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex8]->BuildShape())[kShape];
+  if (!(alpha_shape.size() == dim1Num && alpha_shape[0] == dim1Num)) {
+    if (alpha_shape.size() != dim0Num) {
+      MS_EXCEPTION(ValueError) << "For Sspaddmm, the alpha shape should be (1,) or ()"
+                               << ", but got dim num " << alpha_shape.size() << " and dim0 size " << alpha_shape[0]
+                               << ".";
+    }
+  }
+  if (!(beta_shape.size() == dim1Num && beta_shape[0] == dim1Num)) {
+    if (beta_shape.size() != dim0Num) {
+      MS_EXCEPTION(ValueError) << "For Sspaddmm, the beta shape should be (1,) or ()"
+                               << ", but got dim num " << beta_shape.size() << " and dim0 size " << beta_shape[0]
+                               << ".";
+    }
+  }
+}
+
+void CheckInputTensor(const std::vector<AbstractBasePtr> &input_args) {
+  auto x1_indices_shape =
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
+  auto x1_values_shape =
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
+  auto x1_shape_shape =
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
+  auto x2_indices_shape =
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
+  auto x2_values_shape =
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
+  auto x2_shape_shape =
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape];
+  auto x3_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->BuildShape())[kShape];
+
+  if (x1_indices_shape.size() != dim2Num) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm, x1_indices should be a 2-D tensor"
+                             << ", while the x1_indices dim num is " << x1_indices_shape.size() << ".";
+  }
+  if (x1_indices_shape[0] != dim2Num) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm, the x1_indices shape should be (2, n)"
+                             << ", while the x1_indices shape dim0 is " << x1_indices_shape[0] << ".";
+  }
+  if (x1_values_shape.size() != dim1Num) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm, x1_values should be a 1-D tensor"
+                             << ", while the x1_values dim num is " << x1_values_shape.size() << ".";
+  }
+  if (x1_indices_shape[1] != x1_values_shape[0]) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm"
+                             << ", the dim1 size of `x1_indices` and the dim0 size of `x1_values` should be the same,"
+                             << " while the x1_indices dim1 size is " << x1_indices_shape[1]
+                             << " and the x1_values dim0 size is " << x1_values_shape[0] << ".";
+  }
+  if (x1_shape_shape.size() != dim1Num) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm"
+                             << ", x1_shape should be a 1-D tensor, while the x1_shape dim num is "
+                             << x1_shape_shape.size() << ".";
+  }
+  if (x1_shape_shape[0] != dim2Num) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm"
+                             << ", the shape of x1_shape should be [2] but got shape [" << x1_shape_shape[0] << "].";
+  }
+  if (x2_indices_shape.size() != dim2Num) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm, x2_indices should be a 2-D tensor"
+                             << ", while the x2_indices dim num is " << x2_indices_shape.size() << ".";
+  }
+  if (x2_indices_shape[0] != dim2Num) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm, the x2_indices shape should be (2, n)"
+                             << ", while the x2_indices shape dim0 is " << x2_indices_shape[0] << ".";
+  }
+  if (x2_values_shape.size() != dim1Num) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm, x2_values should be a 1-D tensor"
+                             << ", while the x2_values dim num is " << x2_values_shape.size() << ".";
+  }
+  if (x2_indices_shape[1] != x2_values_shape[0]) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm"
+                             << ", the dim1 size of `x2_indices` and the dim0 size of `x2_values` should be the same,"
+                             << " while the x2_indices dim1 size is " << x2_indices_shape[1]
+                             << " and the x2_values dim0 size is " << x2_values_shape[0] << ".";
+  }
+  if (x2_shape_shape.size() != dim1Num) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm"
+                             << ", x2_shape should be a 1-D tensor, while the x2_shape dim num is "
+                             << x2_shape_shape.size() << ".";
+  }
+  if (x2_shape_shape[0] != dim2Num) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm"
+                             << ", the shape of x2_shape should be [2] but got shape [" << x2_shape_shape[0] << "].";
+  }
+  if (x3_shape.size() != dim2Num) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm, x3_dense should be a 2-D tensor"
+                             << ", while its dim num is " << x3_shape.size() << ".";
+  }
+  CheckAlphaBeta(input_args);
+}
+
+template <typename T>
+void IndicesBoundCheck(T *indices_val, size_t indices_num, T *shape_val, std::string name) {
+  if (shape_val[0] <= 0 || shape_val[1] <= 0) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm, " << name << "_shape should be positive, "
+                             << "while got shape [" << shape_val[0] << ", " << shape_val[1] << "].";
+  }
+  size_t half_num = indices_num / dim2Num;
+  for (size_t i = 0; i < half_num; i++) {
+    if ((indices_val[i] < 0) || (indices_val[i] >= shape_val[0])) {
+      MS_EXCEPTION(ValueError) << "For Sspaddmm, " << name << "_indices row index should be in [0, " << shape_val[0]
+                               << "), while got row index " << indices_val[i] << ".";
+    }
+    if ((indices_val[i + half_num] < 0) || (indices_val[i + half_num] >= shape_val[1])) {
+      MS_EXCEPTION(ValueError) << "For Sspaddmm, " << name << "_indices col index should be in [0, " << shape_val[1]
+                               << "), while got col index " << indices_val[i + half_num] << ".";
+    }
+  }
+}
+
+void CheckIndices(const std::vector<AbstractBasePtr> &input_args) {
+  auto x1_indices_abstract = input_args[kInputIndex0]->cast<abstract::AbstractTensorPtr>();
+  MS_EXCEPTION_IF_NULL(x1_indices_abstract);
+  auto x1_indices_value_ptr = x1_indices_abstract->BuildValue();
+  MS_EXCEPTION_IF_NULL(x1_indices_value_ptr);
+  auto x1_indices_tensor = x1_indices_value_ptr->cast<tensor::TensorPtr>();
+  MS_EXCEPTION_IF_NULL(x1_indices_tensor);
+  auto x1_indices_type = input_args[kInputIndex0]->BuildType();
+  MS_EXCEPTION_IF_NULL(x1_indices_type);
+  auto x1_indices_type_id = x1_indices_type->cast<TensorTypePtr>();
+  MS_EXCEPTION_IF_NULL(x1_indices_type_id);
+  auto x1_indices_type_element = x1_indices_type_id->element();
+  MS_EXCEPTION_IF_NULL(x1_indices_type_element);
+  auto x1_shape_abstract = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
+  MS_EXCEPTION_IF_NULL(x1_shape_abstract);
+  auto x1_shape_value_ptr = x1_shape_abstract->BuildValue();
+  MS_EXCEPTION_IF_NULL(x1_shape_value_ptr);
+  auto x1_shape_tensor = x1_shape_value_ptr->cast<tensor::TensorPtr>();
+  MS_EXCEPTION_IF_NULL(x1_shape_tensor);
+  if (x1_indices_type_element->type_id() == kNumberTypeInt32) {
+    IndicesBoundCheck<int32_t>(reinterpret_cast<int32_t *>(x1_indices_tensor->data_c()),
+                               x1_indices_tensor->DataSize(),
+                               reinterpret_cast<int32_t *>(x1_shape_tensor->data_c()), "x1");
+  } else {
+    IndicesBoundCheck<int64_t>(reinterpret_cast<int64_t *>(x1_indices_tensor->data_c()),
+                               x1_indices_tensor->DataSize(),
+                               reinterpret_cast<int64_t *>(x1_shape_tensor->data_c()), "x1");
+  }
+  auto x2_indices_abstract = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
+  MS_EXCEPTION_IF_NULL(x2_indices_abstract);
+  auto x2_indices_value_ptr = x2_indices_abstract->BuildValue();
+  MS_EXCEPTION_IF_NULL(x2_indices_value_ptr);
+  auto x2_indices_tensor = x2_indices_value_ptr->cast<tensor::TensorPtr>();
+  MS_EXCEPTION_IF_NULL(x2_indices_tensor);
+  auto x2_indices_type = input_args[kInputIndex3]->BuildType();
+  MS_EXCEPTION_IF_NULL(x2_indices_type);
+  auto x2_indices_type_id = x2_indices_type->cast<TensorTypePtr>();
+  MS_EXCEPTION_IF_NULL(x2_indices_type_id);
+  auto x2_indices_type_element = x2_indices_type_id->element();
+  MS_EXCEPTION_IF_NULL(x2_indices_type_element);
+  auto x2_shape_abstract = input_args[kInputIndex5]->cast<abstract::AbstractTensorPtr>();
+  MS_EXCEPTION_IF_NULL(x2_shape_abstract);
+  auto x2_shape_value_ptr = x2_shape_abstract->BuildValue();
+  MS_EXCEPTION_IF_NULL(x2_shape_value_ptr);
+  auto x2_shape_tensor = x2_shape_value_ptr->cast<tensor::TensorPtr>();
+  MS_EXCEPTION_IF_NULL(x2_shape_tensor);
+  if (x2_indices_type_element->type_id() == kNumberTypeInt32) {
+    IndicesBoundCheck<int32_t>(reinterpret_cast<int32_t *>(x2_indices_tensor->data_c()),
+                               x2_indices_tensor->DataSize(),
+                               reinterpret_cast<int32_t *>(x2_shape_tensor->data_c()), "x2");
+  } else {
+    IndicesBoundCheck<int64_t>(reinterpret_cast<int64_t *>(x2_indices_tensor->data_c()),
+                               x2_indices_tensor->DataSize(),
+                               reinterpret_cast<int64_t *>(x2_shape_tensor->data_c()), "x2");
+  }
+}
+
+bool GetDtypeMinAndMaxAndCheckOverFlow(TypePtr tid, int64_t compute_val) {
+  int64_t min = 0;
+  int64_t max = 0;
+  switch (tid->type_id()) {
+    case kNumberTypeUInt8:
+      max = UCHAR_MAX;
+      min = -UCHAR_MAX - 1;
+      break;
+    case kNumberTypeInt8:
+      max = SCHAR_MAX;
+      min = SCHAR_MIN;
+      break;
+    case kNumberTypeInt16:
+      max = SHRT_MAX;
+      min = SHRT_MIN;
+      break;
+    case kNumberTypeInt32:
+      max = INT_MAX;
+      min = INT_MIN;
+      break;
+    default:
+      max = LONG_MAX;
+      min = LONG_MIN;
+      break;
+  }
+  return compute_val <= min || compute_val > max;
+}
+
+void PrintAlphaValueError(TypeId aid, TypePtr expect_dtype, int64_t compute_val, float real, int64_t imag) {
+  if (aid == kNumberTypeComplex64 || aid == kNumberTypeComplex128) {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm"
+                             << ", alpha cannot be converted to the expected dtype " << expect_dtype->ToString()
+                             << " without overflow: (" << real << ", " << imag << ").";
+  } else {
+    MS_EXCEPTION(ValueError) << "For Sspaddmm"
+                             << ", alpha cannot be converted to the expected x2_values dtype "
+                             << expect_dtype->ToString() << " without overflow: " << compute_val << ".";
+  }
+}
+
+abstract::TupleShapePtr SspaddmmInferShape(const PrimitivePtr &primitive,
+                                           const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto op_name = primitive->name();
+  if (input_args[kInputIndex3]->isa<abstract::AbstractTensor>() &&
+      !input_args[kInputIndex3]->BuildValue()->isa<AnyValue>() &&
+      !input_args[kInputIndex3]->BuildValue()->isa<None>()) {
+    if (input_args[kInputIndex7]->isa<abstract::AbstractTensor>() &&
+        !input_args[kInputIndex7]->BuildValue()->isa<AnyValue>() &&
+        !input_args[kInputIndex7]->BuildValue()->isa<None>()) {
+      CheckIndices(input_args);
+      auto alpha_abstract = input_args[kInputIndex7]->cast<abstract::AbstractTensorPtr>();
+      auto alpha_value_ptr = alpha_abstract->BuildValue();
+      MS_EXCEPTION_IF_NULL(alpha_value_ptr);
+      auto alpha_tensor = alpha_value_ptr->cast<tensor::TensorPtr>();
+      MS_EXCEPTION_IF_NULL(alpha_tensor);
+      auto alpha_dtype = input_args[kInputIndex7]->BuildType();
+      MS_EXCEPTION_IF_NULL(alpha_dtype);
+      auto alpha_type_id = alpha_dtype->cast<TensorTypePtr>();
+      MS_EXCEPTION_IF_NULL(alpha_type_id);
+      auto expect_dtype = input_args[kInputIndex1]->BuildType()->cast<TensorTypePtr>()->element();
+      auto alpha_type_element = alpha_type_id->element();
+      float real = 0;
+      int32_t imag = 0;
+      if (alpha_type_element->type_id() == kNumberTypeComplex64) {
+        auto value = reinterpret_cast<std::complex<float> *>(alpha_tensor->data_c());
+        real = value[0].real();
+        imag = value[0].imag();
+      } else if (alpha_type_element->type_id() == kNumberTypeComplex128) {
+        auto value = reinterpret_cast<std::complex<double> *>(alpha_tensor->data_c());
+        real = value[0].real();
+        imag = value[0].imag();
+      }
+      if (imag != 0 || (expect_dtype->type_id() == kNumberTypeUInt8 && real < 0)) {
+        MS_EXCEPTION(ValueError) << "For " << op_name
+                                 << ", the alpha value cannot be converted to type uint8 without overflow: ("
+                                 << real << ", " << imag << ").";
+      }
+      if (!(expect_dtype->type_id() == kNumberTypeFloat32 || expect_dtype->type_id() == kNumberTypeFloat64)) {
+        int64_t compute_val =
+          GetInt64AlphaData(alpha_tensor->data_c(), alpha_type_element->type_id(), expect_dtype, real);
+        if (GetDtypeMinAndMaxAndCheckOverFlow(expect_dtype, compute_val)) {
+          PrintAlphaValueError(alpha_type_element->type_id(), expect_dtype, compute_val, real, imag);
+        }
+      }
+    } else {
+      MS_EXCEPTION(ValueError) << "For Sspaddmm, the value of alpha cannot be obtained.";
+    }
+    auto x2_indices_abstract = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
+    MS_EXCEPTION_IF_NULL(x2_indices_abstract);
+    auto x2_indices_value_ptr = x2_indices_abstract->BuildValue();
+    MS_EXCEPTION_IF_NULL(x2_indices_value_ptr);
+    auto x2_indices_tensor = x2_indices_value_ptr->cast<tensor::TensorPtr>();
+    MS_EXCEPTION_IF_NULL(x2_indices_tensor);
+    auto x2_indices_type = input_args[kInputIndex3]->BuildType();
+    MS_EXCEPTION_IF_NULL(x2_indices_type);
+    auto x2_indices_type_id = x2_indices_type->cast<TensorTypePtr>();
+    MS_EXCEPTION_IF_NULL(x2_indices_type_id);
+    auto x2_indices_type_element = x2_indices_type_id->element();
+    MS_EXCEPTION_IF_NULL(x2_indices_type_element);
+    int64_t x2_indices_unique_size = 0;
+    if (x2_indices_type_element->type_id() == kNumberTypeInt32) {
+      x2_indices_unique_size = compute_output_indices_unique_size_int32(
+        reinterpret_cast<int32_t *>(x2_indices_tensor->data_c()), x2_indices_tensor->DataSize());
+    } else {
+      x2_indices_unique_size = compute_output_indices_unique_size_int64(
+        reinterpret_cast<int64_t *>(x2_indices_tensor->data_c()), x2_indices_tensor->DataSize());
+    }
+    auto x1_indices_shape =
+      CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
+    auto x3_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->BuildShape())[kShape];
+    int64_t x2_indices_shape_right = x2_indices_unique_size * x3_shape[1] + x1_indices_shape[1];
+    std::vector<int64_t> output_indices_shape = {2, x2_indices_shape_right};
+    abstract::ShapePtr output_indices_shape_list =
+      std::make_shared<abstract::Shape>(output_indices_shape, output_indices_shape, output_indices_shape);
+    std::vector<int64_t> output_values_shape = {x2_indices_shape_right};
+    abstract::ShapePtr output_values_shape_list =
+      std::make_shared<abstract::Shape>(output_values_shape, output_values_shape, output_values_shape);
+    auto input_shape = input_args[kInputIndex2]->BuildShape();
+    abstract::ShapePtr output_shape_shape_list = input_shape->cast<abstract::ShapePtr>();
+    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
+      output_indices_shape_list, output_values_shape_list, output_shape_shape_list});
+  } else {
+    std::vector<int64_t> output_shape = {abstract::Shape::SHP_ANY};
+    std::vector<int64_t> infer_shape_min = {0};
+    std::vector<int64_t> infer_shape_max = {MAX_LEN};
+    abstract::ShapePtr output_shape_list =
+      std::make_shared<abstract::Shape>(output_shape, infer_shape_min, infer_shape_max);
+    return std::make_shared<abstract::TupleShape>(
+      std::vector<abstract::BaseShapePtr>{output_shape_list, output_shape_list, output_shape_list});
+  }
+}
+
+TuplePtr SspaddmmInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  auto op_name = prim->name();
+  std::map<std::string, TypePtr> x1_args = {{"x1_indices", input_args[kInputIndex0]->BuildType()},
+                                            {"x1_shape", input_args[kInputIndex2]->BuildType()}};
+  (void)CheckAndConvertUtils::CheckTensorTypeSame(x1_args, {kInt32, kInt64}, op_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("x1_values", input_args[kInputIndex1]->BuildType(),
+                                                   {kUInt8, kInt8, kInt16, kInt32, kInt64, kFloat32, kFloat64},
+                                                   op_name);
+  std::map<std::string, TypePtr> x2_args = {{"x2_indices", input_args[kInputIndex3]->BuildType()},
+                                            {"x2_shape", input_args[kInputIndex5]->BuildType()}};
+  (void)CheckAndConvertUtils::CheckTensorTypeSame(x2_args, {kInt32, kInt64}, op_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("x2_values", input_args[kInputIndex4]->BuildType(),
+                                                   {kUInt8, kInt8, kInt16, kInt32, kInt64, kFloat32, kFloat64},
+                                                   op_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("x3_dense", input_args[kInputIndex6]->BuildType(),
+                                                   {kUInt8, kInt8, kInt16, kInt32, kInt64, kFloat32, kFloat64},
+                                                   op_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid(
+    "alpha", input_args[kInputIndex7]->BuildType(),
+    {kUInt8, kUInt16, kUInt32, kUInt64, kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32, kFloat64}, op_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid(
+    "beta", input_args[kInputIndex8]->BuildType(),
+    {kUInt8, kUInt16, kUInt32, kUInt64, kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32, kFloat64}, op_name);
+  auto expect_dtype = input_args[kInputIndex1]->BuildType()->cast<TensorTypePtr>()->element();
+  auto beta_dtype = input_args[kInputIndex8]->BuildType()->cast<TensorTypePtr>()->element();
+  if (!(expect_dtype->type_id() == kNumberTypeFloat32 || expect_dtype->type_id() == kNumberTypeFloat64)) {
+    auto beta_dtype_id = beta_dtype->type_id();
+    if (beta_dtype_id == kNumberTypeFloat16 || beta_dtype_id == kNumberTypeFloat32 ||
+        beta_dtype_id == kNumberTypeFloat64) {
+      MS_EXCEPTION(TypeError) << "For " << op_name << ", beta dtype: " << beta_dtype->ToString()
+                              << " cannot be converted to the desired output type: " << expect_dtype->ToString()
+                              << ".";
+    }
+  }
+  std::map<std::string, TypePtr> args = {{"x1_values", input_args[kInputIndex1]->BuildType()},
+                                         {"x2_values", input_args[kInputIndex4]->BuildType()},
+                                         {"x3_dense", input_args[kInputIndex6]->BuildType()}};
+  auto output_values_type = CheckAndConvertUtils::CheckTensorTypeSame(
+    args, {kInt8, kInt16, kInt32, kInt64, kUInt8, kFloat32, kFloat64}, op_name);
+  return std::make_shared<Tuple>(std::vector<TypePtr>{kInt64, output_values_type, kInt64});
+}
+}  // namespace
+
+AbstractBasePtr SspaddmmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                              const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  const int64_t input_num = 9;
+  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
+  CheckInputTensor(input_args);
+  auto infer_type = SspaddmmInferType(primitive, input_args);
+  auto infer_shape = SspaddmmInferShape(primitive, input_args);
+  return abstract::MakeAbstract(infer_shape, infer_type);
+}
+
+REGISTER_PRIMITIVE_EVAL_IMPL(Sspaddmm, prim::kPrimSspaddmm, SspaddmmInfer, nullptr, true);
+}  // namespace ops
+}  // namespace mindspore
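As a reading aid for the alpha validation above: when the output dtype is integral, alpha is truncated to int64 by `GetInt64AlphaData` and range-checked by `GetDtypeMinAndMaxAndCheckOverFlow`. A simplified NumPy sketch of the common path (the C++ code additionally widens the uint8 lower bound, rejects negative alpha for uint8 separately, and special-cases bool and complex; `alpha_fits_output_dtype` is an illustrative name):

```python
import numpy as np

def alpha_fits_output_dtype(alpha, values_dtype):
    if np.issubdtype(values_dtype, np.floating):
        return True                       # float32/float64 outputs skip the check
    compute_val = int(alpha)              # truncate toward zero, like static_cast
    info = np.iinfo(values_dtype)
    return info.min <= compute_val <= info.max

assert alpha_fits_output_dtype(1.5, np.int32)       # truncates to 1, fits
assert not alpha_fits_output_dtype(300, np.uint8)   # overflows uint8's range
```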
diff --git a/mindspore/core/ops/sspaddmm.h b/mindspore/core/ops/sspaddmm.h
new file mode 100644
index 00000000000..ba858b39ffe
--- /dev/null
+++ b/mindspore/core/ops/sspaddmm.h
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_SSPADDMM_H_
+#define MINDSPORE_CORE_OPS_SSPADDMM_H_
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "ops/base_operator.h"
+#include "mindapi/base/types.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameSspaddmm = "Sspaddmm";
+/// \brief Performs a matrix multiplication of the sparse tensor x2 and the dense tensor x3,
+/// then adds the sparse tensor x1 to the result.
+/// Refer to Python API @ref mindspore.ops.Sspaddmm for more details.
+class MIND_API Sspaddmm : public BaseOperator {
+ public:
+  MIND_API_BASE_MEMBER(Sspaddmm);
+  /// \brief Constructor.
+  Sspaddmm() : BaseOperator(kNameSspaddmm) {
+    InitIOName(
+      {"x1_indices", "x1_values", "x1_shape", "x2_indices", "x2_values", "x2_shape", "x3_dense", "alpha", "beta"},
+      {"y_indices", "y_values", "y_shape"});
+  }
+  /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Sspaddmm for the inputs.
+  void Init() {}
+};
+abstract::AbstractBasePtr SspaddmmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                        const std::vector<abstract::AbstractBasePtr> &input_args);
+}  // namespace ops
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CORE_OPS_SSPADDMM_H_
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
index 1fdb131b11b..45e2e2e2e86 100644
--- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
@@ -209,3 +209,4 @@ from .scatter_nd_max import _scatter_nd_max_aicpu
 from .conj import _conj_aicpu
 from .scatter_nd_min import _scatter_nd_min_aicpu
 from .cholesky import _cholesky_aicpu
+from .sspaddmm import _sspaddmm_aicpu
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/sspaddmm.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/sspaddmm.py
new file mode 100644
index 00000000000..48f0963fe0f
--- /dev/null
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/sspaddmm.py
@@ -0,0 +1,97 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Sspaddmm op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+
+def _reg_aicpu():
+    return AiCPURegOp("Sspaddmm") \
+        .fusion_type("OPAQUE") \
+        .input(0, "x1_indices", "required") \
+        .input(1, "x1_values", "required") \
+        .input(2, "x1_shape", "required") \
+        .input(3, "x2_indices", "required") \
+        .input(4, "x2_values", "required") \
+        .input(5, "x2_shape", "required") \
+        .input(6, "x3_dense", "required") \
+        .input(7, "alpha", "required") \
+        .input(8, "beta", "required") \
+        .output(0, "y_indices", "required") \
+        .output(1, "y_values", "required") \
+        .output(2, "y_shape", "required")
+
+
+def _reg_format(op_info, dtype1, dtype2, dtype3, alpha, beta):
+    return op_info.dtype_format(dtype1, dtype3, dtype1, dtype2, dtype3, dtype2, dtype3, alpha, beta,
+                                DataType.I64_Default, dtype3, DataType.I64_Default)
+
+
+def _reg_format_beta(op_info, dtype1, dtype2, alpha, beta):
+    op_info = _reg_format(op_info, dtype1, dtype2, DataType.U8_Default, alpha, beta)
+    op_info = _reg_format(op_info, dtype1, dtype2, DataType.I8_Default, alpha, beta)
+    op_info = _reg_format(op_info, dtype1, dtype2, DataType.I16_Default, alpha, beta)
+    op_info = _reg_format(op_info, dtype1, dtype2, DataType.I32_Default, alpha, beta)
+    op_info = _reg_format(op_info, dtype1, dtype2, DataType.I64_Default, alpha, beta)
+    op_info = _reg_format(op_info, dtype1, dtype2, DataType.F32_Default, alpha, beta)
+    op_info = _reg_format(op_info, dtype1, dtype2, DataType.F64_Default, alpha, beta)
+    return op_info
+
+
+def _reg_format_alpha(op_info, dtype1, dtype2, alpha):
+    """alpha reg"""
+    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.U8_Default)
+    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.U16_Default)
+    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.U32_Default)
+    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.U64_Default)
+    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.I8_Default)
+    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.I16_Default)
+    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.I32_Default)
+    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.I64_Default)
+    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.F16_Default)
+    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.F32_Default)
+    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.F64_Default)
+    return op_info
+
+
+def _reg_format_indices(op_info, dtype1, dtype2):
+    """indices reg"""
+    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.U8_Default)
+    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.U16_Default)
+    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.U32_Default)
+    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.U64_Default)
+    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.I8_Default)
+    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.I16_Default)
+    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.I32_Default)
+    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.I64_Default)
+    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.F16_Default)
+    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.F32_Default)
+    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.F64_Default)
+    return op_info
+
+
+def _reg_format_indices_shape(op_info):
+    """shape reg"""
+    op_info = _reg_format_indices(op_info, DataType.I32_Default, DataType.I32_Default)
+    return op_info.get_op_info()
+
+
+sspaddmm_op_info = _reg_format_indices_shape(_reg_aicpu())
+
+
+@op_info_register(sspaddmm_op_info)
+def _sspaddmm_aicpu():
+    """Sspaddmm AiCPU register"""
+    return
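The nested `_reg_format_*` helpers above expand into a cartesian product of dtype rows, with the indices and shape dtypes pinned to int32 by `_reg_format_indices_shape`. A quick sanity sketch of the resulting row count:

```python
# 11 alpha dtypes x 11 beta dtypes x 7 value dtypes per (indices, shape) combo:
alpha_dtypes, beta_dtypes, value_dtypes = 11, 11, 7
print(alpha_dtypes * beta_dtypes * value_dtypes)  # 847 dtype_format rows
```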
diff --git a/mindspore/python/mindspore/ops/_op_impl/cpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/cpu/__init__.py
index 05e3a26b9f3..8be63957df1 100644
--- a/mindspore/python/mindspore/ops/_op_impl/cpu/__init__.py
+++ b/mindspore/python/mindspore/ops/_op_impl/cpu/__init__.py
@@ -78,3 +78,4 @@ from .buffer_sample import _buffer_sample_cpu
 from .priority_replay_buffer import _prb_push_op_cpu
 from .priority_replay_buffer import _prb_sample_op_cpu
 from .space_to_batch_nd import _space_to_batch_nd_cpu
+from .sspaddmm import _sspaddmm_cpu
diff --git a/mindspore/python/mindspore/ops/_op_impl/cpu/sspaddmm.py b/mindspore/python/mindspore/ops/_op_impl/cpu/sspaddmm.py
new file mode 100644
index 00000000000..ae2e4229cd8
--- /dev/null
+++ b/mindspore/python/mindspore/ops/_op_impl/cpu/sspaddmm.py
@@ -0,0 +1,95 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Sspaddmm op"""
+from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
+
+sspaddmm_op_info = CpuRegOp("Sspaddmm") \
+    .input(0, "x1_indices", "required") \
+    .input(1, "x1_values", "required") \
+    .input(2, "x1_shape", "required") \
+    .input(3, "x2_indices", "required") \
+    .input(4, "x2_values", "required") \
+    .input(5, "x2_shape", "required") \
+    .input(6, "x3_dense", "required") \
+    .input(7, "alpha", "required") \
+    .input(8, "beta", "required") \
+    .output(0, "y_indices", "required") \
+    .output(1, "y_values", "required") \
+    .output(2, "y_shape", "required")
+
+
+def _reg_format(op_info, dtype, alpha, beta):
+    return op_info.dtype_format(DataType.I32_Default, dtype, DataType.I32_Default, DataType.I32_Default, dtype,
+                                DataType.I32_Default, dtype, alpha, beta, DataType.I64_Default, dtype,
+                                DataType.I64_Default)
+
+
+def _reg_format_beta(op_info, alpha, beta):
+    op_info = _reg_format(op_info, DataType.U8_Default, alpha, beta)
+    op_info = _reg_format(op_info, DataType.I8_Default, alpha, beta)
+    op_info = _reg_format(op_info, DataType.I16_Default, alpha, beta)
+    op_info = _reg_format(op_info, DataType.I32_Default, alpha, beta)
+    op_info = _reg_format(op_info, DataType.I64_Default, alpha, beta)
+    op_info = _reg_format(op_info, DataType.F32_Default, alpha, beta)
+    op_info = _reg_format(op_info, DataType.F64_Default, alpha, beta)
+    return op_info
+
+
+def _reg_format_alpha(op_info, alpha):
+    """alpha reg"""
+    op_info = _reg_format_beta(op_info, alpha, DataType.U8_Default)
+    op_info = _reg_format_beta(op_info, alpha, DataType.U16_Default)
+    op_info = _reg_format_beta(op_info, alpha, DataType.U32_Default)
+    op_info = _reg_format_beta(op_info, alpha, DataType.U64_Default)
+    op_info = _reg_format_beta(op_info, alpha, DataType.I8_Default)
+    op_info = _reg_format_beta(op_info, alpha, DataType.I16_Default)
+    op_info = _reg_format_beta(op_info, alpha, DataType.I32_Default)
+    op_info = _reg_format_beta(op_info, alpha, DataType.I64_Default)
+    op_info = _reg_format_beta(op_info, alpha, DataType.F16_Default)
+    op_info = _reg_format_beta(op_info, alpha, DataType.F32_Default)
+    op_info = _reg_format_beta(op_info, alpha, DataType.F64_Default)
+    return op_info
+
+
+def _reg_format_indices(op_info):
+    """indices reg"""
+    op_info = _reg_format_alpha(op_info, DataType.U8_Default)
+    op_info = _reg_format_alpha(op_info, DataType.U16_Default)
+    op_info = _reg_format_alpha(op_info, DataType.U32_Default)
+    op_info = _reg_format_alpha(op_info, DataType.U64_Default)
+    op_info = _reg_format_alpha(op_info, DataType.I8_Default)
+    op_info = _reg_format_alpha(op_info, DataType.I16_Default)
+    op_info = _reg_format_alpha(op_info, DataType.I32_Default)
+    op_info = _reg_format_alpha(op_info, DataType.I64_Default)
+    op_info = _reg_format_alpha(op_info, DataType.F16_Default)
+    op_info = _reg_format_alpha(op_info, DataType.F32_Default)
+    op_info = _reg_format_alpha(op_info, DataType.F64_Default)
+    return op_info
+
+
+def _reg_format_indices_shape(op_info):
+    """shape reg"""
+    op_info = _reg_format_indices(op_info)
+    return op_info.get_op_info()
+
+
+sspaddmm_op_info_all = _reg_format_indices_shape(sspaddmm_op_info)
+
+
+@op_info_register(sspaddmm_op_info_all)
+def _sspaddmm_cpu():
+    """Sspaddmm cpu register"""
+    return
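The `Sspaddmm` docstring below closes with a worked example. It can be cross-checked against the dense formula by scattering the uncoalesced COO output back into a dense matrix (a review aid using the example's values, with `alpha = beta = 1`):

```python
import numpy as np

x1 = np.zeros((3, 3), np.int32); x1[0, 0], x1[1, 1] = 1, 2
x2 = np.zeros((3, 3), np.int32); x2[0, 2], x2[1, 2] = 3, 4
x3 = np.array([[1, 2, 3], [1, 3, 2], [3, 2, 1]], np.int32)

# Expected op outputs, taken from the docstring example:
y_indices = np.array([[0, 1, 0, 0, 0, 1, 1, 1], [0, 1, 0, 1, 2, 0, 1, 2]])
y_values = np.array([1, 2, 9, 6, 3, 12, 8, 4], np.int32)

dense = np.zeros((3, 3), np.int32)
np.add.at(dense, (y_indices[0], y_indices[1]), y_values)  # coalesce duplicates
assert (dense == x1 + x2 @ x3).all()
```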
diff --git a/mindspore/python/mindspore/ops/operations/sparse_ops.py b/mindspore/python/mindspore/ops/operations/sparse_ops.py
index bc31bbc66db..358a17ca0b0 100644
--- a/mindspore/python/mindspore/ops/operations/sparse_ops.py
+++ b/mindspore/python/mindspore/ops/operations/sparse_ops.py
@@ -379,3 +379,103 @@ class DenseToDenseSetOperation(Primitive):
         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y_indices', 'y_values', 'y_shape'])
         validator.check_value_type("set_operation", set_operation, [str], self.name)
         validator.check_value_type("validate_indices", validate_indices, [bool], self.name)
+
+
+class Sspaddmm(Primitive):
+    r"""
+    Matrix-multiplies a sparse tensor `x2` with a dense tensor `x3`, then adds the sparse tensor `x1` to the result.
+    If `x1_shape` is :math:`(s0, s1)`, `x2_shape` should be :math:`(s0, s2)` and the shape of `x3` should be
+    :math:`(s2, s1)`.
+
+    .. warning::
+        This is an experimental prototype that is subject to change and/or deletion.
+
+    .. math::
+        out = \beta * x1 + \alpha * (x2 @ x3)
+
+    Inputs:
+        - **x1_indices** (Tensor) - A 2-D Tensor, represents the position of each element in the sparse tensor.
+          Support int32, int64. The shape is :math:`(2, n)`. If `x1_shape` is :math:`(s0, s1)`, each row index
+          value of `x1_indices` should be a non-negative int number less than `s0`, and each col index value
+          should be a non-negative int number less than `s1`.
+        - **x1_values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in
+          `x1_indices`. Support float32, float64, int8, int16, int32, int64, uint8. The dtype should be the same
+          as `x2_values` and `x3_dense`. The shape should be :math:`(n,)`.
+        - **x1_shape** (Tensor) - A 1-D Tensor, specifies the shape of the sparse tensor. Support int32, int64,
+          with 2 positive int elements, shape is :math:`(2,)`. The dtype should be the same as `x1_indices`.
+        - **x2_indices** (Tensor) - A 2-D Tensor, represents the position of each element in the sparse tensor.
+          Support int32, int64. The shape is :math:`(2, n)`. If `x2_shape` is :math:`(s0, s2)`, each row index
+          value of `x2_indices` should be a non-negative int number less than `s0`, and each col index value
+          should be a non-negative int number less than `s2`.
+        - **x2_values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in
+          `x2_indices`. Support float32, float64, int8, int16, int32, int64, uint8. The dtype should be the same
+          as `x1_values` and `x3_dense`. The shape should be :math:`(n,)`.
+        - **x2_shape** (Tensor) - A 1-D Tensor, specifies the shape of the sparse tensor. Support int32, int64,
+          with 2 positive int elements, shape is :math:`(2,)`. The dtype should be the same as `x2_indices`.
+        - **x3_dense** (Tensor) - A 2-D Tensor, the dtype should be the same as `x1_values` and `x2_values`.
+        - **alpha** (Tensor) - A 0-D or 1-D Tensor, the weight of :math:`x2 @ x3`. If `alpha` is a 0-D tensor,
+          its shape is :math:`()`; if it is a 1-D tensor, its shape should be :math:`(1,)`. Support uint8,
+          uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, float64. If the dtype of `alpha`
+          is not the same as the expected output dtype, the alpha value must be convertible to that dtype
+          without overflow.
+        - **beta** (Tensor) - A 0-D or 1-D Tensor, the weight of `x1`. If `beta` is a 0-D tensor, its shape is
+          :math:`()`; if it is a 1-D tensor, its shape should be :math:`(1,)`. Support uint8, uint16, uint32,
+          uint64, int8, int16, int32, int64, float16, float32, float64. If the dtype of `x1_values` is uint8,
+          int8, int16, int32 or int64, the dtype of `beta` does not support float16, float32 or float64.
+
+    Outputs:
+        - **y_indices** (Tensor) - A 2-D Tensor, represents the position of each element in the sparse tensor.
+          The dtype is int64, each element value should be a non-negative int number. The shape is :math:`(2, n)`.
+        - **y_values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in `y_indices`.
+          The dtype is the same as `x1_values`. The shape should be :math:`(n,)`.
+        - **y_shape** (Tensor) - A 1-D Tensor, specifies the shape of the output sparse tensor. The dtype is
+          int64, and its value is the same as `x1_shape`.
+
+    Raises:
+        TypeError: If the dtypes of `x1_indices` and `x1_shape` are not the same, or are neither int32 nor int64.
+        TypeError: If the dtypes of `x2_indices` and `x2_shape` are not the same, or are neither int32 nor int64.
+        TypeError: If the dtypes of `x1_values`, `x2_values` and `x3_dense` are not the same.
+        TypeError: If the dtype of `x1_values`, `x2_values` or `x3_dense` is not uint8, int8, int16, int32, int64,
+            float32 or float64.
+        ValueError: If the shape of `x1_indices` or `x2_indices` is not (2, n).
+        ValueError: If the shape of `x1_values` or `x2_values` is not (n,).
+        ValueError: If the dim0 size of `x1_values` is not the same as the dim1 size of `x1_indices`.
+        ValueError: If the dim0 size of `x2_values` is not the same as the dim1 size of `x2_indices`.
+        ValueError: If the shape of `x1_shape` or `x2_shape` is not (2,).
+        ValueError: If `x3_dense` is not a 2-D tensor.
+        ValueError: If the dtype of `alpha` is not the same as the `x2_values` dtype and converting the alpha
+            value to that dtype overflows.
+        TypeError: If the dtype of `alpha` or `beta` is not uint8, uint16, uint32, uint64, int8, int16, int32,
+            int64, float16, float32 or float64.
+        TypeError: If the dtype of `x1_values` is uint8, int8, int16, int32 or int64, while the dtype of `beta`
+            is float16, float32 or float64.
+        ValueError: If the shape of `alpha` or `beta` is not () or (1,).
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> x1_indices = Tensor(np.array([[0, 1], [0, 1]]), mstype.int64)
+        >>> x1_values = Tensor(np.array([1, 2]), mstype.int32)
+        >>> x1_shape = Tensor(np.array([3, 3]), mstype.int64)
+        >>> x2_indices = Tensor(np.array([[0, 1], [2, 2]]), mstype.int64)
+        >>> x2_values = Tensor(np.array([3, 4]), mstype.int32)
+        >>> x2_shape = Tensor(np.array([3, 3]), mstype.int64)
+        >>> x3_dense = Tensor(np.array([[1, 2, 3], [1, 3, 2], [3, 2, 1]]), mstype.int32)
+        >>> alpha = Tensor(np.array(1), mstype.int32)
+        >>> beta = Tensor(np.array(1), mstype.int32)
+        >>> sspaddmm = ops.Sspaddmm()
+        >>> out_indices, out_values, out_shapes = sspaddmm(x1_indices, x1_values, x1_shape,
+        ...
x2_indices, x2_values, x2_shape, x3_dense, alpha, beta) + >>> print(out_indices) + [[0 1 0 0 0 1 1 1] + [0 1 0 1 2 0 1 2]] + >>> print(out_values) + [ 1 2 9 6 3 12 8 4] + >>> print(out_shapes) + [3 3] + """ + + @prim_attr_register + def __init__(self): + """Initialize Sspaddmm.""" + self.init_prim_io_names(inputs=['x1_indices', 'x1_values', 'x1_shape', 'x2_indices', 'x2_values', 'x2_shape', + 'x3_dense', 'alpha', 'beta'], outputs=['y_indices', 'y_values', 'y_shape']) diff --git a/tests/mindspore_test_framework/utils/block_util.py b/tests/mindspore_test_framework/utils/block_util.py index cf76a144d0a..ba900cd4c33 100644 --- a/tests/mindspore_test_framework/utils/block_util.py +++ b/tests/mindspore_test_framework/utils/block_util.py @@ -225,6 +225,11 @@ class InputOpNet(nn.Cell): return x + def construct9_c0(self, x1, x2, x3, x4, x5, x6, x7, x8, x9): + x = self.op(x1, x2, x3, x4, x5, x6, x7, x8, x9) + return x + + def gen_net(op, input_num, training=True, desc_const=(), const_first=False, add_fake_input=False): if isinstance(op, nn.Cell): return op diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 789ffa6217d..aabe9eb64c6 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -73,7 +73,7 @@ from mindspore.ops.operations.nn_ops import MaxPoolV1 from mindspore.ops.operations.array_ops import NonZero from mindspore.ops.operations._grad_ops import MaxPoolGradV1 from mindspore.ops.operations.nn_ops import ReLUV3 -from mindspore.ops.operations.sparse_ops import DenseToCSRSparseMatrix +from mindspore.ops.operations.sparse_ops import DenseToCSRSparseMatrix, Sspaddmm from mindspore.ops.operations.other_ops import BlackmanWindow from mindspore.nn.layer import normalization from mindspore.ops.operations.array_ops import RightShift @@ -1605,6 +1605,18 @@ test_case_math_ops = [ 'desc_inputs': [Tensor(np.array([[1, 0], [0, 1]]).astype(np.float32)), Tensor(np.array([[0, 0], [1, 1]]).astype(np.int32))], 'skip': ['backward']}), + ('Sspaddmm', { + 'block': Sspaddmm(), + 'desc_inputs': [Tensor(np.array([[0, 1], [0, 1]]).astype(np.int64)), + Tensor(np.array([1, 2]).astype(np.int32)), + Tensor(np.array([3, 3]).astype(np.int64)), + Tensor(np.array([[0, 1], [2, 2]]).astype(np.int64)), + Tensor(np.array([3, 4]).astype(np.int32)), + Tensor(np.array([3, 3]).astype(np.int64)), + Tensor(np.array([[1, 2, 3], [1, 3, 2], [3, 2, 1]]).astype(np.int32)), + Tensor(np.array(1).astype(np.int32)), + Tensor(np.array(1).astype(np.int32))], + 'skip': ['backward']}), ('Embedding_1', { 'block': Embedding(vocab_size=10, embedding_size=3), 'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],