update aicpu operator sspaddmm

Lan-Ling 2022-06-14 12:52:25 +08:00
parent 7f56bf8e60
commit 16aad8be1b
13 changed files with 1292 additions and 2 deletions

View File

@@ -0,0 +1,371 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/sspaddmm_cpu_kernel.h"
#include <algorithm>
#include <complex>
#include <map>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kInputsNum = 9;
constexpr size_t kOutputsNum = 3;
constexpr char kKernelName[] = "Sspaddmm";
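// The infer-type check (SspaddmmInferType) guarantees that a sparse input's indices and shape
// tensors share a dtype, so dispatching on the indices dtype alone is sufficient here.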
#define CHECKSPARSEINDICES(dtype1, dtype2, indices, shapes, num, x_name) \
if (dtype1 == kNumberTypeInt32) { \
CheckSparseIndicesLegal<int32_t, int32_t>(indices, shapes, num, x_name); \
} else { \
CheckSparseIndicesLegal<int64_t, int64_t>(indices, shapes, num, x_name); \
}
} // namespace
void SspaddmmCPUKernelMod::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
CHECK_KERNEL_INPUTS_NUM(input_num, kInputsNum, kKernelName);
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kOutputsNum, kKernelName);
}
void SspaddmmCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
CheckParam(kernel_node);
output_values_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex1);
input_indices_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex0);
input_shape_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex2);
mat1_indices_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex3);
mat1_shape_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex5);
alpha_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex7);
beta_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex8);
auto input_indices_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex0);
auto mat1_indices_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex3);
auto y_indices_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0);
auto mat2_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex6);
input_values_num_ = input_indices_shape[1];
mat1_values_num_ = mat1_indices_shape[1];
y_values_num_ = y_indices_shape[1];
mat2_row_ = mat2_shape[0];
mat2_col_ = mat2_shape[1];
}
bool SspaddmmCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
switch (output_values_dtype_) {
case kNumberTypeUInt8: {
LaunchKernel<uint8_t>(inputs, outputs);
break;
}
case kNumberTypeInt8: {
LaunchKernel<int8_t>(inputs, outputs);
break;
}
case kNumberTypeInt16: {
LaunchKernel<int16_t>(inputs, outputs);
break;
}
case kNumberTypeInt32: {
LaunchKernel<int32_t>(inputs, outputs);
break;
}
case kNumberTypeInt64: {
LaunchKernel<int64_t>(inputs, outputs);
break;
}
case kNumberTypeFloat32: {
LaunchKernel<float>(inputs, outputs);
break;
}
case kNumberTypeFloat64: {
LaunchKernel<double>(inputs, outputs);
break;
}
default: {
MS_EXCEPTION(TypeError) << "For Sspaddmm, The output dtype error.";
}
}
return true;
}
template <typename T>
void SspaddmmCPUKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
auto input_indices_addr = inputs[0]->addr;
auto input_values_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto input_shape_addr = inputs[2]->addr;
auto mat1_indices_addr = inputs[3]->addr;
auto mat1_values_addr = reinterpret_cast<T *>(inputs[4]->addr);
auto mat1_shape_addr = inputs[5]->addr;
auto mat2_addr = reinterpret_cast<T *>(inputs[6]->addr);
auto alpha_val_addr = inputs[7]->addr;
auto beta_val_addr = inputs[8]->addr;
auto y_indices_addr = reinterpret_cast<int64_t *>(outputs[0]->addr);
auto y_values_addr = reinterpret_cast<T *>(outputs[1]->addr);
auto y_shape_addr = reinterpret_cast<int64_t *>(outputs[2]->addr);
CHECKSPARSEINDICES(input_indices_dtype_, input_shape_dtype_, input_indices_addr, input_shape_addr, input_values_num_,
"x1");
CHECKSPARSEINDICES(mat1_indices_dtype_, mat1_shape_dtype_, mat1_indices_addr, mat1_shape_addr, mat1_values_num_,
"x2");
int64_t mat1_row, mat1_col, input_row, input_col;
if (mat1_shape_dtype_ == kNumberTypeInt32) {
auto mat1_shape_val = reinterpret_cast<int32_t *>(mat1_shape_addr);
mat1_row = static_cast<int64_t>(mat1_shape_val[0]);
mat1_col = static_cast<int64_t>(mat1_shape_val[1]);
} else {
auto mat1_shape_val = reinterpret_cast<int64_t *>(mat1_shape_addr);
mat1_row = mat1_shape_val[0];
mat1_col = mat1_shape_val[1];
}
if (input_shape_dtype_ == kNumberTypeInt32) {
auto input_shape_val = reinterpret_cast<int32_t *>(input_shape_addr);
input_row = static_cast<int64_t>(input_shape_val[0]);
input_col = static_cast<int64_t>(input_shape_val[1]);
} else {
auto input_shape_val = reinterpret_cast<int64_t *>(input_shape_addr);
input_row = input_shape_val[0];
input_col = input_shape_val[1];
}
if (mat1_col != static_cast<int64_t>(mat2_row_)) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, the sparse tensor x2's dense shape col:" << mat1_col
<< " should be equal to x3_dense shape row:" << mat2_row_ << ".";
}
if (mat1_row != input_row || static_cast<int64_t>(mat2_col_) != input_col) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, the sparse x1's dense shape "
<< "[" << input_row << "," << input_col << "] should equal to x2@x3_dense shape ["
<< mat1_row << "," << mat2_col_ << "].";
}
if (input_shape_dtype_ == kNumberTypeInt32) {
InitShape<int32_t>(input_shape_addr, y_shape_addr);
} else {
InitShape<int64_t>(input_shape_addr, y_shape_addr);
}
ClearSparseValues<T>(y_values_addr, y_values_num_);
// scalar * sparse inplace
T *input_values_addr_bak = ScalarSparseMul<T>(input_values_addr, beta_val_addr, input_values_num_, beta_dtype_);
T *mat1_values_addr_bak = ScalarSparseMul<T>(mat1_values_addr, alpha_val_addr, mat1_values_num_, alpha_dtype_);
// sparse + sparse
if (input_indices_dtype_ == kNumberTypeInt32) {
SparseAddSparse<int32_t, T>(input_indices_addr, input_values_addr_bak, input_values_num_, y_indices_addr,
y_values_addr, y_values_num_);
} else {
SparseAddSparse<int64_t, T>(input_indices_addr, input_values_addr_bak, input_values_num_, y_indices_addr,
y_values_addr, y_values_num_);
}
auto row = mat1_col;
auto col = input_col;
if (mat1_indices_dtype_ == kNumberTypeInt32) {
SparseMulDense<int32_t, T>(mat1_indices_addr, mat1_values_addr_bak, mat1_values_num_, mat2_addr, y_indices_addr,
y_values_addr, y_values_num_, row, col);
} else {
SparseMulDense<int64_t, T>(mat1_indices_addr, mat1_values_addr_bak, mat1_values_num_, mat2_addr, y_indices_addr,
y_values_addr, y_values_num_, row, col);
}
// Release the scaled copies allocated by ScalarSparseMul to avoid leaking them on every launch.
delete[] input_values_addr_bak;
delete[] mat1_values_addr_bak;
}
template <typename T, typename S>
void SspaddmmCPUKernelMod::CheckSparseIndicesLegal(void *indices_addr, void *shape_addr, size_t num,
std::string x_name) {
auto indices_val = reinterpret_cast<T *>(indices_addr);
auto shape_val = reinterpret_cast<S *>(shape_addr);
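// indices are a flattened (2, num) buffer: entries [0, num) hold row indices, [num, 2*num) hold col indices.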
int shape_num = 2;
for (int i = 0; i < shape_num; i++) {
if (shape_val[i] <= 0) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, the " << x_name << "_value should be positive"
<< " while get shape [" << shape_val[0] << ", " << shape_val[1] << "]";
}
}
for (size_t i = 0; i < num; i++) {
int64_t row = static_cast<int64_t>(shape_val[0]);
int64_t col = static_cast<int64_t>(shape_val[1]);
int64_t indices_row = static_cast<int64_t>(indices_val[i]);
int64_t indices_col = static_cast<int64_t>(indices_val[i + num]);
if ((indices_row >= row) || indices_col >= col || indices_row < 0 || indices_col < 0) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, the " << x_name << "_indices"
<< " row value:" << indices_row << ", col value: " << indices_col << " out of bounds."
<< " Row should between [0," << row << "].Col should between [0," << col << "].";
}
}
}
template <typename T>
void SspaddmmCPUKernelMod::InitShape(void *input_shape, int64_t *y_shape) {
auto input_shape_val = reinterpret_cast<T *>(input_shape);
size_t shape_num = 2;
for (size_t i = 0; i < shape_num; i++) {
y_shape[i] = static_cast<int64_t>(input_shape_val[i]);
}
}
template <typename T>
void SspaddmmCPUKernelMod::ClearSparseValues(T *sparse_val, size_t data_num) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
sparse_val[i] = static_cast<T>(0);
}
};
ParallelLaunchAutoSearch(task, data_num, this, &parallel_search_info_);
}
// Scalar * sparse values: computes beta * x1_values and alpha * x2_values into a new buffer.
template <typename T>
T *SspaddmmCPUKernelMod::ScalarSparseMul(T *sparse_val, void *scalar_val, size_t data_num, TypeId tid) {
T val;
if (data_num == 0) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, the number of values must be greater than 0.";
}
T *sparse_val_bak = new T[data_num];
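// sparse_val_bak is heap-allocated; LaunchKernel releases it once the launch completes.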
switch (tid) {
case kNumberTypeUInt8:
val = static_cast<T>(reinterpret_cast<uint8_t *>(scalar_val)[0]);
break;
case kNumberTypeUInt16:
val = static_cast<T>(reinterpret_cast<uint16_t *>(scalar_val)[0]);
break;
case kNumberTypeUInt32:
val = static_cast<T>(reinterpret_cast<uint32_t *>(scalar_val)[0]);
break;
case kNumberTypeUInt64:
val = static_cast<T>(reinterpret_cast<uint64_t *>(scalar_val)[0]);
break;
case kNumberTypeInt8:
val = static_cast<T>(reinterpret_cast<int8_t *>(scalar_val)[0]);
break;
case kNumberTypeInt16:
val = static_cast<T>(reinterpret_cast<int16_t *>(scalar_val)[0]);
break;
case kNumberTypeInt32:
val = static_cast<T>(reinterpret_cast<int32_t *>(scalar_val)[0]);
break;
case kNumberTypeInt64:
val = static_cast<T>(reinterpret_cast<int64_t *>(scalar_val)[0]);
break;
case kNumberTypeFloat16:
val = static_cast<T>(reinterpret_cast<float16 *>(scalar_val)[0]);
break;
case kNumberTypeFloat32:
val = static_cast<T>(reinterpret_cast<float *>(scalar_val)[0]);
break;
case kNumberTypeFloat64:
val = static_cast<T>(reinterpret_cast<double *>(scalar_val)[0]);
break;
case kNumberTypeBool:
val = static_cast<T>(reinterpret_cast<bool *>(scalar_val)[0]);
break;
case kNumberTypeComplex64:
val = static_cast<T>(reinterpret_cast<std::complex<float> *>(scalar_val)[0].real());
break;
case kNumberTypeComplex128:
val = static_cast<T>(reinterpret_cast<std::complex<double> *>(scalar_val)[0].real());
break;
default:
MS_EXCEPTION(TypeError) << "For Sspaddmm, dtype not support. ";
break;
}
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
sparse_val_bak[i] = sparse_val[i] * val;
}
};
ParallelLaunchAutoSearch(task, data_num, this, &parallel_search_info_);
return sparse_val_bak;
}
// Sparse + sparse: copies the beta-scaled x1 nonzeros into the output,
// ahead of the x2 @ x3_dense results written later by SparseMulDense.
template <typename T, typename S>
void SspaddmmCPUKernelMod::SparseAddSparse(void *input_indices, S *input_values, size_t input_num, int64_t *y_indices,
S *y_values, size_t y_num) {
// Copy each x1 nonzero (row, col, value) into the first input_num output slots.
auto input_ids = reinterpret_cast<T *>(input_indices);
this->cnt_ = input_num;
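// cnt_ tracks how many output slots are occupied; SparseMulDense appends its results after them.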
// Write the output values and indices in parallel.
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
auto row = input_ids[i];
auto col = input_ids[i + input_num];
y_values[i] = input_values[i];
y_indices[i] = static_cast<int64_t>(row);
y_indices[i + y_num] = static_cast<int64_t>(col);
}
};
ParallelLaunchAutoSearch(task, input_num, this, &parallel_search_info_);
}
template <typename T, typename S>
void SspaddmmCPUKernelMod::SparseMulDense(void *mat1_indices, S *mat1_values, size_t mat1_vals_num, S *mat2_addr,
int64_t *y_indices, S *y_values, size_t y_vals_num, int64_t row,
int64_t mat2_col_) {
// The result of mat1 @ mat2 is written directly into the output tensors.
auto mat1_ids = reinterpret_cast<T *>(mat1_indices);
std::unordered_map<T, std::unordered_map<int64_t, uint32_t>> idx_map_cnt;
std::unordered_map<T, std::vector<T>> unrepeated;
std::unordered_map<T, std::unordered_map<T, std::vector<S>>> co_map_idx;
// unrepeated maps each nonzero row of mat1 to its column indices, e.g. [1 -> [0], 2 -> [1, 2]].
// co_map_idx stacks the mat1 values at each (row, col) pair, e.g. co_map_idx[1][0] -> 0.3.
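// idx_map_cnt reserves one output slot per (mat1 row, mat2 col) pair, numbering the slots
// consecutively after the cnt_ entries already holding the x1 nonzeros.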
for (size_t i = 0; i < mat1_vals_num; i++) {
T _row = mat1_ids[i];
T _col = mat1_ids[i + mat1_vals_num];
unrepeated[_row].push_back(_col);
co_map_idx[_row][_col].push_back(mat1_values[i]);
for (int64_t j = 0; j < mat2_col_; j++) {
if (idx_map_cnt[_row][j] == 0) {
idx_map_cnt[_row][j] = this->cnt_;
this->cnt_++;
}
}
}
std::vector<T> res;
for (auto it = unrepeated.begin(); it != unrepeated.end(); it++) {
res.push_back(it->first);
}
size_t n_unrepeat = unrepeated.size();
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
// get val
auto row_mat1 = res[i];
for (auto row_mat2 : unrepeated[row_mat1]) {
S val = co_map_idx[row_mat1][row_mat2].back();
co_map_idx[row_mat1][row_mat2].pop_back();
for (int64_t j = 0; j < mat2_col_; j++) {
// get val
T idx = idx_map_cnt[row_mat1][j];
*(y_values + idx) += val * mat2_addr[row_mat2 * mat2_col_ + j];
y_indices[idx] = static_cast<int64_t>(row_mat1);
y_indices[idx + y_vals_num] = j;
}
}
}
};
ParallelLaunchAutoSearch(task, n_unrepeat, this, &parallel_search_info_);
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Sspaddmm, SspaddmmCPUKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,72 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SSPADDMM_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SSPADDMM_CPU_KERNEL_H_
#include <unordered_map>
#include <vector>
#include <string>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class SspaddmmCPUKernelMod : public DeprecatedNativeCpuKernelMod {
public:
SspaddmmCPUKernelMod() = default;
~SspaddmmCPUKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
void CheckParam(const CNodePtr &kernel_node);
template <typename T, typename S>
void CheckSparseIndicesLegal(void *indices_addr, void *shape_addr, size_t num, std::string x_name);
template <typename T>
void InitShape(void *input_shape, int64_t *y_shape);
template <typename T>
void ClearSparseValues(T *sparse_val, size_t data_num);
template <typename T>
T *ScalarSparseMul(T *sparse_val, void *scalar_val, size_t data_num, TypeId tid);
template <typename T, typename S>
void SparseAddSparse(void *input_indices, S *input_values, size_t input_num, int64_t *y_indices, S *y_values,
size_t y_num);
template <typename T, typename S>
void SparseMulDense(void *mat1_indices, S *mat1_values, size_t mat1_vals_num, S *mat2_values_tensor,
int64_t *y_indices, S *y_values, size_t y_vals_num, int64_t row, int64_t mat2_col);
TypeId output_values_dtype_{kTypeUnknown};
TypeId input_indices_dtype_{kTypeUnknown};
TypeId input_shape_dtype_{kTypeUnknown};
TypeId mat1_indices_dtype_{kTypeUnknown};
TypeId mat1_shape_dtype_{kTypeUnknown};
TypeId alpha_dtype_{kTypeUnknown};
TypeId beta_dtype_{kTypeUnknown};
size_t input_values_num_{0};
size_t mat1_values_num_{0};
size_t y_values_num_{0};
size_t cnt_{0};
size_t mat2_row_{0};
size_t mat2_col_{0};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SSPADDMM_CPU_KERNEL_H_

View File

@@ -91,6 +91,7 @@ PrimShapeDependMap &GetHostDependsMap() {
static const auto &kSegmentSum = prim::kPrimSegmentSum->name();
static const auto &kBlackmanWindow = prim::kPrimBlackmanWindow->name();
static const auto &kExpand = prim::kPrimExpand->name();
static const auto &kSspaddmm = prim::kPrimSspaddmm->name();
// Common host depends.
static PrimShapeDependMap host_depends{{kSegmentMax, ShapeSet{1}},
{kSegmentMin, ShapeSet{1}},
@@ -132,7 +133,8 @@ PrimShapeDependMap &GetHostDependsMap() {
{kSparseSegmentMean, ShapeSet{2}},
{kResizeLinear1D, ShapeSet{1}},
{kBlackmanWindow, ShapeSet{0}},
{kExpand, ShapeSet{1}}};
{kExpand, ShapeSet{1}},
{kSspaddmm, ShapeSet{0, 2, 3, 5, 7}}};
return host_depends;
}

View File

@@ -340,6 +340,7 @@ GVAR_DEF(PrimitivePtr, kPrimGatherNd, std::make_shared<Primitive>("GatherNd"));
GVAR_DEF(PrimitivePtr, kPrimSparseGatherV2, std::make_shared<Primitive>("SparseGatherV2"));
GVAR_DEF(PrimitivePtr, kPrimCoalesce, std::make_shared<Primitive>(kCoalesce));
GVAR_DEF(PrimitivePtr, kPrimSparseToDense, std::make_shared<Primitive>("SparseToDense"));
GVAR_DEF(PrimitivePtr, kPrimSspaddmm, std::make_shared<Primitive>("Sspaddmm"));
GVAR_DEF(PrimitivePtr, kPrimShape, std::make_shared<Primitive>("Shape"));
GVAR_DEF(PrimitivePtr, kPrimStridedSlice, std::make_shared<Primitive>(kStridedSlice));
GVAR_DEF(PrimitivePtr, kPrimStridedSliceGrad, std::make_shared<Primitive>(kStridedSliceGrad));

View File

@@ -0,0 +1,484 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <vector>
#include <memory>
#include <complex>
#include <map>
#include <string>
#include <climits>
#include "ops/sspaddmm.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
const int64_t MAX_LEN = 1000000;
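// Counts the distinct row indices in the first half of a flattened (2, n) indices buffer,
// i.e. the number of unique nonzero rows of the sparse matrix.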
template <typename T>
int64_t compute_output_indices_unique_size(T *indices, size_t size) {
std::set<T> mat1_indices_set;
size_t half_size = size / 2;
for (size_t i = 0; i < half_size; i++) {
mat1_indices_set.insert(indices[i]);
}
return static_cast<int64_t>(mat1_indices_set.size());
}
enum DimNum : size_t {
dim0Num = 0,
dim1Num,
dim2Num,
};
int64_t GetInt64AlphaDataOther(void *values, TypeId tid, TypePtr expect_dtype, float real) {
int64_t compute_val = 0;
bool flag1 = (expect_dtype->type_id() == kNumberTypeUInt8);
bool flag2 = false;
switch (tid) {
case kNumberTypeFloat16:
flag2 = flag1 && (reinterpret_cast<float16 *>(values)[0] < static_cast<float16>(0));
compute_val = static_cast<int64_t>(reinterpret_cast<float16 *>(values)[0]);
break;
case kNumberTypeFloat32:
flag2 = flag1 && (reinterpret_cast<float *>(values)[0] < 0);
compute_val = static_cast<int64_t>(reinterpret_cast<float *>(values)[0]);
break;
case kNumberTypeFloat64:
flag2 = flag1 && (reinterpret_cast<double *>(values)[0] < 0);
compute_val = static_cast<int64_t>(reinterpret_cast<double *>(values)[0]);
break;
case kNumberTypeBool:
compute_val = static_cast<int64_t>(reinterpret_cast<bool *>(values)[0]);
break;
case kNumberTypeComplex64:
case kNumberTypeComplex128:
compute_val = static_cast<int64_t>(real);
break;
default:
MS_EXCEPTION(TypeError) << "For Sspaddmm, alpha dtype is not support, only support"
<< " number type and bool, complex64, complex128. ";
break;
}
if (flag2) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, alpha value cannot be converted to type uint8 without overflow. ";
}
return compute_val;
}
int64_t GetInt64AlphaData(void *values, TypeId tid, TypePtr expect_dtype, float real) {
int64_t compute_val = 0;
switch (tid) {
case kNumberTypeUInt8:
compute_val = static_cast<int64_t>(reinterpret_cast<uint8_t *>(values)[0]);
break;
case kNumberTypeUInt16:
compute_val = static_cast<int64_t>(reinterpret_cast<uint16_t *>(values)[0]);
break;
case kNumberTypeUInt32:
compute_val = static_cast<int64_t>(reinterpret_cast<uint32_t *>(values)[0]);
break;
case kNumberTypeUInt64:
compute_val = static_cast<int64_t>(reinterpret_cast<uint64_t *>(values)[0]);
break;
case kNumberTypeInt8:
compute_val = static_cast<int64_t>(reinterpret_cast<int8_t *>(values)[0]);
break;
case kNumberTypeInt16:
compute_val = static_cast<int64_t>(reinterpret_cast<int16_t *>(values)[0]);
break;
case kNumberTypeInt32:
compute_val = static_cast<int64_t>(reinterpret_cast<int32_t *>(values)[0]);
break;
case kNumberTypeInt64:
compute_val = static_cast<int64_t>(reinterpret_cast<int64_t *>(values)[0]);
break;
default:
compute_val = GetInt64AlphaDataOther(values, tid, expect_dtype, real);
break;
}
return compute_val;
}
void CheckAlphaBeta(const std::vector<AbstractBasePtr> &input_args) {
auto alpha_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex7]->BuildShape())[kShape];
auto beta_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex8]->BuildShape())[kShape];
if (!(alpha_shape.size() == dim1Num && alpha_shape[0] == dim1Num)) {
if (alpha_shape.size() != dim0Num) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, alpha shape should be (1,) or ()"
<< ", but get dim num is " << alpha_shape.size() << ", dim0 size is " << alpha_shape[0]
<< ".";
}
}
if (!(beta_shape.size() == dim1Num && beta_shape[0] == dim1Num)) {
if (beta_shape.size() != dim0Num) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, beta shape should be (1,) or ()"
<< ", but get dim num is " << beta_shape.size() << ", dim0 size is " << beta_shape[0]
<< ".";
}
}
}
void CheckInputTensor(const std::vector<AbstractBasePtr> &input_args) {
auto x1_indices_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto x1_values_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto x1_shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
auto x2_indices_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
auto x2_values_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
auto x2_shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape];
auto x3_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->BuildShape())[kShape];
if (x1_indices_shape.size() != dim2Num) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, x1_indices should be a 2-D tensor"
<< ", while x1_indices dim num is " << x1_indices_shape.size() << ".";
}
if (x1_indices_shape[0] != dim2Num) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, x1_indices shape should be (2, n)"
<< ", while x1_indices shape dim0 is " << x1_indices_shape[0] << ".";
}
if (x1_values_shape.size() != dim1Num) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, x1_values should be a 1-D tensor"
<< ", while x1_values dim num is " << x1_values_shape.size() << ".";
}
if (x1_indices_shape[1] != x1_values_shape[0]) {
MS_EXCEPTION(ValueError) << "For Sspaddmm"
<< ", dim1 size of `x1_indices` and dim0 size of `x1_values` should be the same"
<< " while x1_indices dim1 size is " << x1_indices_shape[1]
<< ", x1_values_shape dim0 size is " << x1_values_shape[0] << ".";
}
if (x1_shape_shape.size() != dim1Num) {
MS_EXCEPTION(ValueError) << "For Sspaddmm"
<< ", x1_shape should be a 1-D tensor, while x1_shape dim num is " << x1_shape_shape.size()
<< ".";
}
if (x1_shape_shape[0] != dim2Num) {
MS_EXCEPTION(ValueError) << "For Sspaddmm"
<< ", the shape of x1_shape should be [2] but got shape [" << x1_shape_shape[0] << "].";
}
if (x2_indices_shape.size() != dim2Num) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, x2_indices should be a 2-D tensor"
<< ", while x2_indices dim num is " << x2_indices_shape.size() << ".";
}
if (x2_indices_shape[0] != dim2Num) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, x2_indices shape should be (2, n)"
<< ", while x2_indices shape dim0 is " << x2_indices_shape[0] << ".";
}
if (x2_values_shape.size() != dim1Num) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, x2_values should be a 1-D tensor"
<< ", while x2_values dim num is " << x2_values_shape.size() << ".";
}
if (x2_indices_shape[1] != x2_values_shape[0]) {
MS_EXCEPTION(ValueError) << "For Sspaddmm"
<< ", dim1 size of `x2_indices` and dim0 size of `x2_values` should be the same"
<< " while x2_indices dim1 size is " << x2_indices_shape[1]
<< ", x2_values_shape dim0 size is " << x2_values_shape[0] << ".";
}
if (x2_shape_shape.size() != dim1Num) {
MS_EXCEPTION(ValueError) << "For Sspaddmm"
<< ", x2_shape should be a 1-D tensor, while x2_shape dim num is " << x2_shape_shape.size()
<< ".";
}
if (x2_shape_shape[0] != dim2Num) {
MS_EXCEPTION(ValueError) << "For Sspaddmm"
<< ", the shape of x2_shape should be [2] but got shape [" << x2_shape_shape[0] << "].";
}
if (x3_shape.size() != dim2Num) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, x3_dense should be a 2-D tensor"
<< ", while dim num is " << x3_shape.size() << ".";
}
CheckAlphaBeta(input_args);
}
template <typename T>
void IndicesBoundCheck(T *indices_val, size_t indices_num, T *shape_val, std::string name) {
if (shape_val[0] <= 0 || shape_val[1] <= 0) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, " << name << "_shape should be positive, "
<< "while got shape [" << shape_val[0] << ", " << shape_val[1] << "].";
}
size_t half_num = indices_num / dim2Num;
for (size_t i = 0; i < half_num; i++) {
if ((indices_val[i] < 0) || (indices_val[i] >= shape_val[0])) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, " << name << "_indices row index should between [0, " << shape_val[0]
<< "], while got row index " << indices_val[i] << ".";
}
if ((indices_val[i + half_num] < 0) || (indices_val[i + half_num] >= shape_val[1])) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, " << name << "_indices col index should between [0, " << shape_val[1]
<< "], while got col index " << indices_val[i + half_num] << ".";
}
}
}
void CheckIndices(const std::vector<AbstractBasePtr> &input_args) {
auto x1_indices_abstract = input_args[kInputIndex0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(x1_indices_abstract);
auto x1_indices_value_ptr = x1_indices_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(x1_indices_value_ptr);
auto x1_indices_tensor = x1_indices_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(x1_indices_tensor);
auto x1_indices_type = input_args[kInputIndex0]->BuildType();
MS_EXCEPTION_IF_NULL(x1_indices_type);
auto x1_indices_type_id = x1_indices_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(x1_indices_type_id);
auto x1_indices_type_element = x1_indices_type_id->element();
MS_EXCEPTION_IF_NULL(x1_indices_type_element);
auto x1_shape_abstract = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(x1_shape_abstract);
auto x1_shape_value_ptr = x1_shape_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(x1_shape_value_ptr);
auto x1_shape_tensor = x1_shape_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(x1_shape_tensor);
if (x1_indices_type_element->type_id() == kNumberTypeInt32) {
IndicesBoundCheck<int32_t>(reinterpret_cast<int32_t *>(x1_indices_tensor->data_c()), x1_indices_tensor->DataSize(),
reinterpret_cast<int32_t *>(x1_shape_tensor->data_c()), "x1");
} else {
IndicesBoundCheck<int64_t>(reinterpret_cast<int64_t *>(x1_indices_tensor->data_c()), x1_indices_tensor->DataSize(),
reinterpret_cast<int64_t *>(x1_shape_tensor->data_c()), "x1");
}
auto x2_indices_abstract = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(x2_indices_abstract);
auto x2_indices_value_ptr = x2_indices_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(x2_indices_value_ptr);
auto x2_indices_tensor = x2_indices_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(x2_indices_tensor);
auto x2_indices_type = input_args[kInputIndex3]->BuildType();
MS_EXCEPTION_IF_NULL(x2_indices_type);
auto x2_indices_type_id = x2_indices_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(x2_indices_type_id);
auto x2_indices_type_element = x2_indices_type_id->element();
MS_EXCEPTION_IF_NULL(x2_indices_type_element);
auto x2_shape_abstract = input_args[kInputIndex5]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(x2_shape_abstract);
auto x2_shape_value_ptr = x2_shape_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(x2_shape_value_ptr);
auto x2_shape_tensor = x2_shape_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(x2_shape_tensor);
if (x2_indices_type_element->type_id() == kNumberTypeInt32) {
IndicesBoundCheck<int32_t>(reinterpret_cast<int32_t *>(x2_indices_tensor->data_c()), x2_indices_tensor->DataSize(),
reinterpret_cast<int32_t *>(x2_shape_tensor->data_c()), "x2");
} else {
IndicesBoundCheck<int64_t>(reinterpret_cast<int64_t *>(x2_indices_tensor->data_c()), x2_indices_tensor->DataSize(),
reinterpret_cast<int64_t *>(x2_shape_tensor->data_c()), "x2");
}
}
bool GetDtypeMinAndMaxAndCheckOverFlow(TypePtr tid, int64_t compute_val) {
int64_t min = 0;
int64_t max = 0;
switch (tid->type_id()) {
case kNumberTypeUInt8:
max = UCHAR_MAX;
min = -UCHAR_MAX - 1;
break;
case kNumberTypeInt8:
max = SCHAR_MAX;
min = SCHAR_MIN;
break;
case kNumberTypeInt16:
max = SHRT_MAX;
min = SHRT_MIN;
break;
case kNumberTypeInt32:
max = INT_MAX;
min = INT_MIN;
break;
default:
max = LLONG_MAX;
min = LLONG_MIN;
break;
}
return compute_val <= min || compute_val > max;
}
void PrintAlphaValueError(TypeId aid, TypePtr expect_dtype, int64_t compute_val, float real, double imag) {
if (aid == kNumberTypeComplex64 || aid == kNumberTypeComplex128) {
MS_EXCEPTION(ValueError) << "For Sspaddmm"
<< ", alpha cannot be converted to the expected dtype " << expect_dtype->ToString()
<< " without overflow: (" << real << ", " << imag << ").";
} else {
MS_EXCEPTION(ValueError) << "For Sspaddmm"
<< ", alpha cannot be converted to the expected x2_values dtype " << expect_dtype->ToString()
<< " without overflow: " << compute_val << ".";
}
}
abstract::TupleShapePtr SspaddmmInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
if (input_args[kInputIndex3]->isa<abstract::AbstractTensor>() &&
!input_args[kInputIndex3]->BuildValue()->isa<AnyValue>() &&
!input_args[kInputIndex3]->BuildValue()->isa<None>()) {
if (input_args[kInputIndex7]->isa<abstract::AbstractTensor>() &&
!input_args[kInputIndex7]->BuildValue()->isa<AnyValue>() &&
!input_args[kInputIndex7]->BuildValue()->isa<None>()) {
CheckIndices(input_args);
auto alpha_abstract = input_args[kInputIndex7]->cast<abstract::AbstractTensorPtr>();
auto alpha_value_ptr = alpha_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(alpha_value_ptr);
auto alpha_tensor = alpha_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(alpha_tensor);
auto alpha_dtype = input_args[kInputIndex7]->BuildType();
MS_EXCEPTION_IF_NULL(alpha_dtype);
auto alpha_type_id = alpha_dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(alpha_type_id);
auto expect_dtype = input_args[kInputIndex1]->BuildType()->cast<TensorTypePtr>()->element();
auto alpha_type_element = alpha_type_id->element();
float real = 0;
double imag = 0;
if (alpha_type_element->type_id() == kNumberTypeComplex64) {
auto value = reinterpret_cast<std::complex<float> *>(alpha_tensor->data_c());
real = value[0].real();
imag = value[0].imag();
} else if (alpha_type_element->type_id() == kNumberTypeComplex128) {
auto value = reinterpret_cast<std::complex<double> *>(alpha_tensor->data_c());
real = value[0].real();
imag = value[0].imag();
}
if (imag != 0 || (expect_dtype->type_id() == kNumberTypeUInt8 && real < 0)) {
MS_EXCEPTION(ValueError) << "For " << op_name
<< ", alpha value cannot be converted to type uint8 , without overflow: (" << real
<< ", " << imag << ").";
}
if (!(expect_dtype->type_id() == kNumberTypeFloat32 || expect_dtype->type_id() == kNumberTypeFloat64)) {
int64_t compute_val =
GetInt64AlphaData(alpha_tensor->data_c(), alpha_type_element->type_id(), expect_dtype, real);
if (GetDtypeMinAndMaxAndCheckOverFlow(expect_dtype, compute_val)) {
PrintAlphaValueError(alpha_type_element->type_id(), expect_dtype, compute_val, real, imag);
}
}
} else {
MS_EXCEPTION(ValueError) << "For Sspaddmm, alpha value cant't get.";
}
auto x2_indices_abstract = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(x2_indices_abstract);
auto x2_indices_value_ptr = x2_indices_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(x2_indices_value_ptr);
auto x2_indices_tensor = x2_indices_value_ptr->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(x2_indices_tensor);
auto x2_indices_type = input_args[kInputIndex3]->BuildType();
MS_EXCEPTION_IF_NULL(x2_indices_type);
auto x2_indices_type_id = x2_indices_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(x2_indices_type_id);
auto x2_indices_type_element = x2_indices_type_id->element();
MS_EXCEPTION_IF_NULL(x2_indices_type_element);
int64_t x2_indices_unique_size = 0;
if (x2_indices_type_element->type_id() == kNumberTypeInt32) {
x2_indices_unique_size = compute_output_indices_unique_size<int32_t>(
reinterpret_cast<int32_t *>(x2_indices_tensor->data_c()), x2_indices_tensor->DataSize());
} else {
x2_indices_unique_size = compute_output_indices_unique_size<int64_t>(
reinterpret_cast<int64_t *>(x2_indices_tensor->data_c()), x2_indices_tensor->DataSize());
}
auto x1_indices_shape =
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto x3_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->BuildShape())[kShape];
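// The output holds the x1 nonzeros unchanged plus x3_shape[1] entries from x2 @ x3_dense
// for every unique nonzero row of x2, which determines the output nnz.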
int64_t x2_indices_shape_right = x2_indices_unique_size * x3_shape[1] + x1_indices_shape[1];
std::vector<int64_t> output_indices_shape = {2, x2_indices_shape_right};
abstract::ShapePtr output_indices_shape_list =
std::make_shared<abstract::Shape>(output_indices_shape, output_indices_shape, output_indices_shape);
std::vector<int64_t> output_values_shape = {x2_indices_shape_right};
abstract::ShapePtr output_values_shape_list =
std::make_shared<abstract::Shape>(output_values_shape, output_values_shape, output_values_shape);
auto input_shape = input_args[kInputIndex2]->BuildShape();
abstract::ShapePtr output_shape_shape_list = input_shape->cast<abstract::ShapePtr>();
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
output_indices_shape_list, output_values_shape_list, output_shape_shape_list});
} else {
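// The x2_indices or alpha values are unknown at this point, so fall back to a dynamic
// 1-D shape bounded by MAX_LEN for all three outputs.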
std::vector<int64_t> output_shape = {abstract::Shape::SHP_ANY};
std::vector<int64_t> infer_shape_min = {0};
std::vector<int64_t> infer_shape_max = {MAX_LEN};
abstract::ShapePtr output_shape_list =
std::make_shared<abstract::Shape>(output_shape, infer_shape_min, infer_shape_max);
return std::make_shared<abstract::TupleShape>(
std::vector<abstract::BaseShapePtr>{output_shape_list, output_shape_list, output_shape_list});
}
}
TuplePtr SspaddmmInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = prim->name();
std::map<std::string, TypePtr> x1_args = {{"x1_indices", input_args[kInputIndex0]->BuildType()},
{"x1_shape", input_args[kInputIndex2]->BuildType()}};
(void)CheckAndConvertUtils::CheckTensorTypeSame(x1_args, {kInt32, kInt64}, op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("x1_values", input_args[kInputIndex1]->BuildType(),
{kUInt8, kInt8, kInt16, kInt32, kInt64, kFloat32, kFloat64},
op_name);
std::map<std::string, TypePtr> x2_args = {{"x2_indices", input_args[kInputIndex3]->BuildType()},
{"x2_shape", input_args[kInputIndex5]->BuildType()}};
(void)CheckAndConvertUtils::CheckTensorTypeSame(x2_args, {kInt32, kInt64}, op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("x2_values", input_args[kInputIndex4]->BuildType(),
{kUInt8, kInt8, kInt16, kInt32, kInt64, kFloat32, kFloat64},
op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid("x3_dense", input_args[kInputIndex6]->BuildType(),
{kUInt8, kInt8, kInt16, kInt32, kInt64, kFloat32, kFloat64},
op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid(
"alpha", input_args[kInputIndex7]->BuildType(),
{kUInt8, kUInt16, kUInt32, kUInt64, kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32, kFloat64}, op_name);
(void)CheckAndConvertUtils::CheckTensorTypeValid(
"beta", input_args[kInputIndex8]->BuildType(),
{kUInt8, kUInt16, kUInt32, kUInt64, kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32, kFloat64}, op_name);
auto expect_dtype = input_args[kInputIndex1]->BuildType()->cast<TensorTypePtr>()->element();
auto beta_dtype = input_args[kInputIndex8]->BuildType()->cast<TensorTypePtr>()->element();
if (!(expect_dtype->type_id() == kNumberTypeFloat32 || expect_dtype->type_id() == kNumberTypeFloat64)) {
auto beta_dtype_id = beta_dtype->type_id();
if (beta_dtype_id == kNumberTypeFloat16 || beta_dtype_id == kNumberTypeFloat32 ||
beta_dtype_id == kNumberTypeFloat64) {
MS_EXCEPTION(TypeError) << "For " << op_name << ",beta dtype: " << beta_dtype->ToString()
<< " can't convert to the desired output type: " << expect_dtype->ToString() << ".";
}
}
std::map<std::string, TypePtr> args = {{"x1_values", input_args[kInputIndex1]->BuildType()},
{"x2_values", input_args[kInputIndex4]->BuildType()},
{"x3_dense", input_args[kInputIndex6]->BuildType()}};
auto output_values_type = CheckAndConvertUtils::CheckTensorTypeSame(
args, {kInt8, kInt16, kInt32, kInt64, kUInt8, kFloat32, kFloat64}, op_name);
return std::make_shared<Tuple>(std::vector<TypePtr>{kInt64, output_values_type, kInt64});
}
} // namespace
AbstractBasePtr SspaddmmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 9;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
CheckInputTensor(input_args);
auto infer_type = SspaddmmInferType(primitive, input_args);
auto infer_shape = SspaddmmInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Sspaddmm, prim::kPrimSspaddmm, SspaddmmInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@@ -0,0 +1,49 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SSPADDMM_H_
#define MINDSPORE_CORE_OPS_SSPADDMM_H_
#include <vector>
#include <set>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSspaddmm = "Sspaddmm";
/// \brief Performs a matrix multiplication of the sparse matrix x2 and the dense matrix x3_dense,
/// then adds the sparse matrix x1 to the result.
/// Refer to Python API @ref mindspore.ops.Sspaddmm for more details.
class MIND_API Sspaddmm : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Sspaddmm);
/// \brief Constructor.
Sspaddmm() : BaseOperator(kNameSspaddmm) {
InitIOName(
{"x1_indices", "x1_values", "x1_shape", "x2_indices", "x2_values", "x2_shape", "x3_dense", "alpha", "beta"},
{"y_indices", "y_values", "y_shape"});
}
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Sspaddmm for the inputs.
void Init() {}
};
abstract::AbstractBasePtr SspaddmmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SSPADDMM_H_

View File

@@ -209,3 +209,4 @@ from .scatter_nd_max import _scatter_nd_max_aicpu
from .conj import _conj_aicpu
from .scatter_nd_min import _scatter_nd_min_aicpu
from .cholesky import _cholesky_aicpu
from .sspaddmm import _sspaddmm_aicpu

View File

@@ -0,0 +1,97 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Sspaddmm op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
def _reg_aicpu():
return AiCPURegOp("Sspaddmm") \
.fusion_type("OPAQUE") \
.input(0, "x1_indices", "required") \
.input(1, "x1_values", "required") \
.input(2, "x1_shape", "required") \
.input(3, "x2_indices", "required") \
.input(4, "x2_values", "required") \
.input(5, "x2_shape", "required") \
.input(6, "x3_dense", "required") \
.input(7, "alpha", "required") \
.input(8, "beta", "required") \
.output(0, "y_indices", "required") \
.output(1, "y_values", "required") \
.output(2, "y_shape", "required")
def _reg_format(op_info, dtype1, dtype2, dtype3, alpha, beta):
return op_info.dtype_format(dtype1, dtype3, dtype1, dtype2, dtype3, dtype2, dtype3, alpha, beta,
DataType.I64_Default, dtype3, DataType.I64_Default)
def _reg_format_beta(op_info, dtype1, dtype2, alpha, beta):
op_info = _reg_format(op_info, dtype1, dtype2, DataType.U8_Default, alpha, beta)
op_info = _reg_format(op_info, dtype1, dtype2, DataType.I8_Default, alpha, beta)
op_info = _reg_format(op_info, dtype1, dtype2, DataType.I16_Default, alpha, beta)
op_info = _reg_format(op_info, dtype1, dtype2, DataType.I32_Default, alpha, beta)
op_info = _reg_format(op_info, dtype1, dtype2, DataType.I64_Default, alpha, beta)
op_info = _reg_format(op_info, dtype1, dtype2, DataType.F32_Default, alpha, beta)
op_info = _reg_format(op_info, dtype1, dtype2, DataType.F64_Default, alpha, beta)
return op_info
def _reg_format_alpha(op_info, dtype1, dtype2, alpha):
"""alpha reg"""
op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.U8_Default)
op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.U16_Default)
op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.U32_Default)
op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.U64_Default)
op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.I8_Default)
op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.I16_Default)
op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.I32_Default)
op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.I64_Default)
op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.F16_Default)
op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.F32_Default)
op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.F64_Default)
return op_info
def _reg_format_indices(op_info, dtype1, dtype2):
"""indices reg"""
op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.U8_Default)
op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.U16_Default)
op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.U32_Default)
op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.U64_Default)
op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.I8_Default)
op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.I16_Default)
op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.I32_Default)
op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.I64_Default)
op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.F16_Default)
op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.F32_Default)
op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.F64_Default)
return op_info
def _reg_format_indices_shape(op_info):
"""shape reg"""
op_info = _reg_format_indices(op_info, DataType.I32_Default, DataType.I32_Default)
return op_info.get_op_info()
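# The nested helpers above expand the full cross-product of supported alpha/beta dtypes
# for each value dtype, with the indices and shape inputs fixed to int32 here.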
sspaddmm_op_info = _reg_format_indices_shape(_reg_aicpu())
@op_info_register(sspaddmm_op_info)
def _sspaddmm_aicpu():
"""Sspaddmm AiCPU register"""
return

View File

@@ -78,3 +78,4 @@ from .buffer_sample import _buffer_sample_cpu
from .priority_replay_buffer import _prb_push_op_cpu
from .priority_replay_buffer import _prb_sample_op_cpu
from .space_to_batch_nd import _space_to_batch_nd_cpu
from .sspaddmm import _sspaddmm_cpu

View File

@@ -0,0 +1,95 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Sspaddmm op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
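# The helpers below mirror the AiCPU registration: for int32 indices/shape inputs they
# expand the cross-product of supported alpha/beta dtypes for each value dtype.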
sspaddmm_op_info = CpuRegOp("Sspaddmm") \
.input(0, "x1_indices", "required") \
.input(1, "x1_values", "required") \
.input(2, "x1_shape", "required") \
.input(3, "x2_indices", "required") \
.input(4, "x2_values", "required") \
.input(5, "x2_shape", "required") \
.input(6, "x3_dense", "required") \
.input(7, "alpha", "required") \
.input(8, "beta", "required") \
.output(0, "y_indices", "required") \
.output(1, "y_values", "required") \
.output(2, "y_shape", "required")
def _reg_format(op_info, dtype, alpha, beta):
return op_info.dtype_format(DataType.I32_Default, dtype, DataType.I32_Default, DataType.I32_Default, dtype,
DataType.I32_Default, dtype, alpha, beta, DataType.I64_Default, dtype,
DataType.I64_Default)
def _reg_format_beta(op_info, alpha, beta):
op_info = _reg_format(op_info, DataType.U8_Default, alpha, beta)
op_info = _reg_format(op_info, DataType.I8_Default, alpha, beta)
op_info = _reg_format(op_info, DataType.I16_Default, alpha, beta)
op_info = _reg_format(op_info, DataType.I32_Default, alpha, beta)
op_info = _reg_format(op_info, DataType.I64_Default, alpha, beta)
op_info = _reg_format(op_info, DataType.F32_Default, alpha, beta)
op_info = _reg_format(op_info, DataType.F64_Default, alpha, beta)
return op_info
def _reg_format_alpha(op_info, alpha):
"""alpha reg"""
op_info = _reg_format_beta(op_info, alpha, DataType.U8_Default)
op_info = _reg_format_beta(op_info, alpha, DataType.U16_Default)
op_info = _reg_format_beta(op_info, alpha, DataType.U32_Default)
op_info = _reg_format_beta(op_info, alpha, DataType.U64_Default)
op_info = _reg_format_beta(op_info, alpha, DataType.I8_Default)
op_info = _reg_format_beta(op_info, alpha, DataType.I16_Default)
op_info = _reg_format_beta(op_info, alpha, DataType.I32_Default)
op_info = _reg_format_beta(op_info, alpha, DataType.I64_Default)
op_info = _reg_format_beta(op_info, alpha, DataType.F16_Default)
op_info = _reg_format_beta(op_info, alpha, DataType.F32_Default)
op_info = _reg_format_beta(op_info, alpha, DataType.F64_Default)
return op_info
def _reg_format_indices(op_info):
"""indices reg"""
op_info = _reg_format_alpha(op_info, DataType.U8_Default)
op_info = _reg_format_alpha(op_info, DataType.U16_Default)
op_info = _reg_format_alpha(op_info, DataType.U32_Default)
op_info = _reg_format_alpha(op_info, DataType.U64_Default)
op_info = _reg_format_alpha(op_info, DataType.I8_Default)
op_info = _reg_format_alpha(op_info, DataType.I16_Default)
op_info = _reg_format_alpha(op_info, DataType.I32_Default)
op_info = _reg_format_alpha(op_info, DataType.I64_Default)
op_info = _reg_format_alpha(op_info, DataType.F16_Default)
op_info = _reg_format_alpha(op_info, DataType.F32_Default)
op_info = _reg_format_alpha(op_info, DataType.F64_Default)
return op_info
def _reg_format_indices_shape(op_info):
"""shape reg"""
op_info = _reg_format_indices(op_info)
return op_info.get_op_info()
sspaddmm_op_info_all = _reg_format_indices_shape(sspaddmm_op_info)
@op_info_register(sspaddmm_op_info_all)
def _sspaddmm_cpu():
"""Sspaddmm cpu register"""
return

View File

@@ -379,3 +379,103 @@ class DenseToDenseSetOperation(Primitive):
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y_indices', 'y_values', 'y_shape'])
validator.check_value_type("set_operation", set_operation, [str], self.name)
validator.check_value_type("validate_indices", validate_indices, [bool], self.name)
class Sspaddmm(Primitive):
r"""
Matrix multiplies the sparse tensor `x2` with the dense tensor `x3`, then adds the sparse tensor `x1` to the result.
If `x1_shape` is :math:`(s0, s1)`, `x2_shape` should be :math:`(s0, s2)` and `x3_shape` should be :math:`(s2, s1)`.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
.. math::
out = \beta * x1 + \alpha * (x2 @ x3),
Inputs:
- **x1_indices** (Tensor) - A 2-D Tensor, represents the position of each element in the sparse tensor.
Support int32, int64. The shape is :math:`(2, n)`. If `x1_shape` is :math:`(s0, s1)`, each row index of
`x1_indices` should be a non-negative integer less than `s0`, and each col index a non-negative integer
less than `s1`.
- **x1_values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in
the `x1_indices`. Support float32, float64, int8, int16, int32, int64, uint8. The dtype should be the same as
`x2_values` and `x3_dense`. The shape should be :math:`(n,)`.
- **x1_shape** (Tensor) - A 1-D Tensor, specifies the shape of sparse tensor. Support int32, int64,
have 2 positive int elements, shape is :math:`(2,)`. The dtype should be the same as `x1_indices`.
- **x2_indices** (Tensor) - A 2-D Tensor, represents the position of each element in the sparse tensor.
Support int32, int64. The shape is :math:`(2, n)`. If `x2_shape` is :math:`(s0, s2)`, each row index of
`x2_indices` should be a non-negative integer less than `s0`, and each col index a non-negative integer
less than `s2`.
- **x2_values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in the `x2_indices`.
Support float32, float64, int8, int16, int32, int64, uint8. The dtype should be the same as `x1_values`
and `x3_dense`. The shape should be :math:`(n,)`.
- **x2_shape** (Tensor) - A 1-D Tensor, specifies the shape of sparse tensor. Support int32,int64,
have 2 positive int elements, shape is :math:`(2,)`. The dtype is same as `x2_indices`.
- **x3_dense** (Tensor) - A 2-D Tensor, the dtype should be the same as `x1_values` and `x2_values`.
- **alpha** (Tensor) - A 0-D or 1-D Tensor, the weight of x2 @ x3. If `alpha` is a 1-D tensor,
its shape should be :math:`(1,)`; otherwise its shape is :math:`()`. Support uint8, uint16, uint32, uint64,
int8, int16, int32, int64, float16, float32, float64. If the dtype of `alpha` is not the same as the
expected output dtype, the alpha value must be convertible without overflow.
- **beta** (Tensor) - A 0-D or 1-D Tensor, the weight of x1. If `beta` is a 1-D tensor,
its shape should be :math:`(1,)`; otherwise its shape is :math:`()`. Support uint8, uint16, uint32, uint64,
int8, int16, int32, int64, float16, float32, float64. If the `x1_values` dtype is uint8, int8, int16, int32
or int64, the dtype of `beta` cannot be float16, float32 or float64.
Outputs:
- **y_indices** (Tensor) - A 2-D Tensor, represents the position of the element in the sparse tensor.
The dtype is int64, each element value should be a non-negative int number. The shape is :math:`(2, n)`.
- **y_values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in the `y_indices`.
The dtype is the same as `x1_values` . The shape should be :math:`(n,)`.
- **y_shape** (Tensor) - A 1-D Tensor, specifies the shape of the output sparse tensor.
The dtype is int64, and the values are the same as `x1_shape`.
Raises:
TypeError: If dtypes of `x1_indices` and `x1_shape` are not the same, or are neither int32 nor int64.
TypeError: If dtypes of `x2_indices` and `x2_shape` are not the same, or are neither int32 nor int64.
TypeError: If dtypes of `x1_values`, `x2_values` and `x3_dense` are not the same.
TypeError: If dtype of `x1_values`, `x2_values` or `x3_dense` is not uint8, int8, int16, int32, int64, float32
or float64.
ValueError: If shape of `x1_indices`, `x2_indices` is not (2, n).
ValueError: If shape of `x1_values`, `x2_values` is not (n,).
ValueError: If dim0 size of `x1_values` is not the same as dim1 size of `x1_indices`.
ValueError: If dim0 size of `x2_values` is not the same as dim1 size of `x2_indices`.
ValueError: If shape of `x1_shape` or shape of `x2_shape` is not (2,).
ValueError: If dim of `x3_dense` is not 2D.
ValueError: If dtype of `alpha` is not the same as the `x2_values` dtype and converting the alpha value
to the `x2_values` dtype overflows.
TypeError: If dtype of `alpha`, `beta` is not uint8, uint16, uint32, uint64, int8, int16, int32, int64,
float16, float32, float64.
TypeError: If the `x1_values` dtype is uint8, int8, int16, int32 or int64, while the dtype of `beta` is
float16, float32 or float64.
ValueError: If the shape of `alpha`, `beta` is not () or (1,).
Supported Platforms:
``Ascend`` ``CPU``
Examples:
>>> x1_indices = Tensor(np.array([[0, 1], [0, 1]]), mstype.int64)
>>> x1_values = Tensor(np.array([1, 2]), mstype.int32)
>>> x1_shape = Tensor(np.array([3, 3]), mstype.int64)
>>> x2_indices = Tensor(np.array([[0, 1], [2, 2]]), mstype.int64)
>>> x2_values = Tensor(np.array([3, 4]), mstype.int32)
>>> x2_shape = Tensor(np.array([3, 3]), mstype.int64)
>>> x3_dense = Tensor(np.array([[1, 2, 3], [1, 3, 2], [3, 2, 1]]), mstype.int32)
>>> alpha = Tensor(np.array(1), mstype.int32)
>>> beta = Tensor(np.array(1), mstype.int32)
>>> sspaddmm = ops.Sspaddmm()
>>> out_indices, out_values, out_shapes = sspaddmm(x1_indices, x1_values, x1_shape,
... x2_indices, x2_values, x2_shape, x3_dense, alpha, beta)
>>> print(out_indices)
[[0 1 0 0 0 1 1 1]
[0 1 0 1 2 0 1 2]]
>>> print(out_values)
[ 1 2 9 6 3 12 8 4]
>>> print(out_shapes)
[3 3]
"""
@prim_attr_register
def __init__(self):
"""Initialize Sspaddmm."""
self.init_prim_io_names(inputs=['x1_indices', 'x1_values', 'x1_shape', 'x2_indices', 'x2_values', 'x2_shape',
'x3_dense', 'alpha', 'beta'], outputs=['y_indices', 'y_values', 'y_shape'])

View File

@@ -225,6 +225,11 @@ class InputOpNet(nn.Cell):
return x
def construct9_c0(self, x1, x2, x3, x4, x5, x6, x7, x8, x9):
x = self.op(x1, x2, x3, x4, x5, x6, x7, x8, x9)
return x
def gen_net(op, input_num, training=True, desc_const=(), const_first=False, add_fake_input=False):
if isinstance(op, nn.Cell):
return op

View File

@@ -73,7 +73,7 @@ from mindspore.ops.operations.nn_ops import MaxPoolV1
from mindspore.ops.operations.array_ops import NonZero
from mindspore.ops.operations._grad_ops import MaxPoolGradV1
from mindspore.ops.operations.nn_ops import ReLUV3
from mindspore.ops.operations.sparse_ops import DenseToCSRSparseMatrix
from mindspore.ops.operations.sparse_ops import DenseToCSRSparseMatrix, Sspaddmm
from mindspore.ops.operations.other_ops import BlackmanWindow
from mindspore.nn.layer import normalization
from mindspore.ops.operations.array_ops import RightShift
@@ -1605,6 +1605,18 @@ test_case_math_ops = [
'desc_inputs': [Tensor(np.array([[1, 0], [0, 1]]).astype(np.float32)),
Tensor(np.array([[0, 0], [1, 1]]).astype(np.int32))],
'skip': ['backward']}),
('Sspaddmm', {
'block': Sspaddmm(),
'desc_inputs': [Tensor(np.array([[0, 1], [0, 1]]).astype(np.int64)),
Tensor(np.array([1, 2]).astype(np.int32)),
Tensor(np.array([3, 3]).astype(np.int64)),
Tensor(np.array([[0, 1], [2, 2]]).astype(np.int64)),
Tensor(np.array([3, 4]).astype(np.int32)),
Tensor(np.array([3, 3]).astype(np.int64)),
Tensor(np.array([[1, 2, 3], [1, 3, 2], [3, 2, 1]]).astype(np.int32)),
Tensor(np.array(1).astype(np.int32)),
Tensor(np.array(1).astype(np.int32))],
'skip': ['backward']}),
('Embedding_1', {
'block': Embedding(vocab_size=10, embedding_size=3),
'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],