update aicpu operator sspaddmm
This commit is contained in:
parent
7f56bf8e60
commit
16aad8be1b
@ -0,0 +1,371 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/cpu/kernel/sspaddmm_cpu_kernel.h"
#include <algorithm>
#include <complex>
#include <map>
#include <unordered_map>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kInputsNum = 9;
constexpr size_t kOutputsNum = 3;
constexpr char kKernelName[] = "Sspaddmm";

// The infer pass guarantees that an indices tensor and its shape tensor share a dtype,
// so dtype1 alone selects both template arguments.
#define CHECKSPARSEINDICES(dtype1, dtype2, indices, shapes, num, x_name)     \
  if (dtype1 == kNumberTypeInt32) {                                          \
    CheckSparseIndicesLegal<int32_t, int32_t>(indices, shapes, num, x_name); \
  } else {                                                                   \
    CheckSparseIndicesLegal<int64_t, int64_t>(indices, shapes, num, x_name); \
  }
}  // namespace

void SspaddmmCPUKernelMod::CheckParam(const CNodePtr &kernel_node) {
  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
  CHECK_KERNEL_INPUTS_NUM(input_num, kInputsNum, kKernelName);
  size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
  CHECK_KERNEL_OUTPUTS_NUM(output_num, kOutputsNum, kKernelName);
}

void SspaddmmCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  CheckParam(kernel_node);

  output_values_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex1);
  input_indices_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex0);
  input_shape_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex2);
  mat1_indices_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex3);
  mat1_shape_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex5);
  alpha_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex7);
  beta_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex8);

  auto input_indices_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex0);
  auto mat1_indices_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex3);
  auto y_indices_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0);
  auto mat2_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex6);

  input_values_num_ = input_indices_shape[1];
  mat1_values_num_ = mat1_indices_shape[1];
  y_values_num_ = y_indices_shape[1];
  mat2_row_ = mat2_shape[0];
  mat2_col_ = mat2_shape[1];
}

bool SspaddmmCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                  const std::vector<AddressPtr> &outputs) {
  switch (output_values_dtype_) {
    case kNumberTypeUInt8: {
      LaunchKernel<uint8_t>(inputs, outputs);
      break;
    }
    case kNumberTypeInt8: {
      LaunchKernel<int8_t>(inputs, outputs);
      break;
    }
    case kNumberTypeInt16: {
      LaunchKernel<int16_t>(inputs, outputs);
      break;
    }
    case kNumberTypeInt32: {
      LaunchKernel<int32_t>(inputs, outputs);
      break;
    }
    case kNumberTypeInt64: {
      LaunchKernel<int64_t>(inputs, outputs);
      break;
    }
    case kNumberTypeFloat32: {
      LaunchKernel<float>(inputs, outputs);
      break;
    }
    case kNumberTypeFloat64: {
      LaunchKernel<double>(inputs, outputs);
      break;
    }
    default: {
      MS_EXCEPTION(TypeError) << "For Sspaddmm, the output dtype is not supported.";
    }
  }
  return true;
}

template <typename T>
void SspaddmmCPUKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
  auto input_indices_addr = inputs[0]->addr;
  auto input_values_addr = reinterpret_cast<T *>(inputs[1]->addr);
  auto input_shape_addr = inputs[2]->addr;
  auto mat1_indices_addr = inputs[3]->addr;
  auto mat1_values_addr = reinterpret_cast<T *>(inputs[4]->addr);
  auto mat1_shape_addr = inputs[5]->addr;
  auto mat2_addr = reinterpret_cast<T *>(inputs[6]->addr);
  auto alpha_val_addr = inputs[7]->addr;
  auto beta_val_addr = inputs[8]->addr;
  auto y_indices_addr = reinterpret_cast<int64_t *>(outputs[0]->addr);
  auto y_values_addr = reinterpret_cast<T *>(outputs[1]->addr);
  auto y_shape_addr = reinterpret_cast<int64_t *>(outputs[2]->addr);
  CHECKSPARSEINDICES(input_indices_dtype_, input_shape_dtype_, input_indices_addr, input_shape_addr, input_values_num_,
                     "x1");
  CHECKSPARSEINDICES(mat1_indices_dtype_, mat1_shape_dtype_, mat1_indices_addr, mat1_shape_addr, mat1_values_num_,
                     "x2");

  int64_t mat1_row, mat1_col, input_row, input_col;
  if (mat1_shape_dtype_ == kNumberTypeInt32) {
    auto mat1_shape_val = reinterpret_cast<int32_t *>(mat1_shape_addr);
    mat1_row = static_cast<int64_t>(mat1_shape_val[0]);
    mat1_col = static_cast<int64_t>(mat1_shape_val[1]);
  } else {
    auto mat1_shape_val = reinterpret_cast<int64_t *>(mat1_shape_addr);
    mat1_row = mat1_shape_val[0];
    mat1_col = mat1_shape_val[1];
  }
  if (input_shape_dtype_ == kNumberTypeInt32) {
    auto input_shape_val = reinterpret_cast<int32_t *>(input_shape_addr);
    input_row = static_cast<int64_t>(input_shape_val[0]);
    input_col = static_cast<int64_t>(input_shape_val[1]);
  } else {
    auto input_shape_val = reinterpret_cast<int64_t *>(input_shape_addr);
    input_row = input_shape_val[0];
    input_col = input_shape_val[1];
  }

  if (mat1_col != static_cast<int64_t>(mat2_row_)) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm, the col of the sparse tensor x2's dense shape: " << mat1_col
                             << " should be equal to the row of x3_dense's shape: " << mat2_row_ << ".";
  }
  if (mat1_row != input_row || static_cast<int64_t>(mat2_col_) != input_col) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm, the sparse x1's dense shape "
                             << "[" << input_row << ", " << input_col << "] should be equal to the x2@x3_dense shape ["
                             << mat1_row << ", " << mat2_col_ << "].";
  }

  if (input_shape_dtype_ == kNumberTypeInt32) {
    InitShape<int32_t>(input_shape_addr, y_shape_addr);
  } else {
    InitShape<int64_t>(input_shape_addr, y_shape_addr);
  }

  ClearSparseValues<T>(y_values_addr, y_values_num_);

  // scalar * sparse: beta * x1 and alpha * x2, each into a scratch buffer
  T *input_values_addr_bak = ScalarSparseMul<T>(input_values_addr, beta_val_addr, input_values_num_, beta_dtype_);
  T *mat1_values_addr_bak = ScalarSparseMul<T>(mat1_values_addr, alpha_val_addr, mat1_values_num_, alpha_dtype_);

  // sparse + sparse
  if (input_indices_dtype_ == kNumberTypeInt32) {
    SparseAddSparse<int32_t, T>(input_indices_addr, input_values_addr_bak, input_values_num_, y_indices_addr,
                                y_values_addr, y_values_num_);
  } else {
    SparseAddSparse<int64_t, T>(input_indices_addr, input_values_addr_bak, input_values_num_, y_indices_addr,
                                y_values_addr, y_values_num_);
  }
  auto row = mat1_col;
  auto col = input_col;
  if (mat1_indices_dtype_ == kNumberTypeInt32) {
    SparseMulDense<int32_t, T>(mat1_indices_addr, mat1_values_addr_bak, mat1_values_num_, mat2_addr, y_indices_addr,
                               y_values_addr, y_values_num_, row, col);
  } else {
    SparseMulDense<int64_t, T>(mat1_indices_addr, mat1_values_addr_bak, mat1_values_num_, mat2_addr, y_indices_addr,
                               y_values_addr, y_values_num_, row, col);
  }
  // ScalarSparseMul allocates the scratch buffers with new[]; release them here.
  delete[] input_values_addr_bak;
  delete[] mat1_values_addr_bak;
}

template <typename T, typename S>
void SspaddmmCPUKernelMod::CheckSparseIndicesLegal(void *indices_addr, void *shape_addr, size_t num,
                                                   std::string x_name) {
  auto indices_val = reinterpret_cast<T *>(indices_addr);
  auto shape_val = reinterpret_cast<S *>(shape_addr);
  int shape_num = 2;
  for (int i = 0; i < shape_num; i++) {
    if (shape_val[i] <= 0) {
      MS_EXCEPTION(ValueError) << "For Sspaddmm, the " << x_name << "_shape values should be positive,"
                               << " but got shape [" << shape_val[0] << ", " << shape_val[1] << "].";
    }
  }
  for (size_t i = 0; i < num; i++) {
    int64_t row = static_cast<int64_t>(shape_val[0]);
    int64_t col = static_cast<int64_t>(shape_val[1]);
    int64_t indices_row = static_cast<int64_t>(indices_val[i]);
    int64_t indices_col = static_cast<int64_t>(indices_val[i + num]);
    if (indices_row >= row || indices_col >= col || indices_row < 0 || indices_col < 0) {
      MS_EXCEPTION(ValueError) << "For Sspaddmm, the " << x_name << "_indices row value: " << indices_row
                               << ", col value: " << indices_col << " is out of bounds."
                               << " Row should be in [0, " << row << ") and col should be in [0, " << col << ").";
    }
  }
}

template <typename T>
void SspaddmmCPUKernelMod::InitShape(void *input_shape, int64_t *y_shape) {
  auto input_shape_val = reinterpret_cast<T *>(input_shape);
  size_t shape_num = 2;
  for (size_t i = 0; i < shape_num; i++) {
    y_shape[i] = static_cast<int64_t>(input_shape_val[i]);
  }
}

template <typename T>
void SspaddmmCPUKernelMod::ClearSparseValues(T *sparse_val, size_t data_num) {
  auto task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) {
      sparse_val[i] = static_cast<T>(0);
    }
  };
  ParallelLaunchAutoSearch(task, data_num, this, &parallel_search_info_);
}

// scalar * sparse matrix, used for beta * input and alpha * mat1
template <typename T>
T *SspaddmmCPUKernelMod::ScalarSparseMul(T *sparse_val, void *scalar_val, size_t data_num, TypeId tid) {
  T val;
  if (data_num == 0) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm, data_num should be greater than 0.";
  }
  T *sparse_val_bak = new T[data_num];
  switch (tid) {
    case kNumberTypeUInt8:
      val = static_cast<T>(reinterpret_cast<uint8_t *>(scalar_val)[0]);
      break;
    case kNumberTypeUInt16:
      val = static_cast<T>(reinterpret_cast<uint16_t *>(scalar_val)[0]);
      break;
    case kNumberTypeUInt32:
      val = static_cast<T>(reinterpret_cast<uint32_t *>(scalar_val)[0]);
      break;
    case kNumberTypeUInt64:
      val = static_cast<T>(reinterpret_cast<uint64_t *>(scalar_val)[0]);
      break;
    case kNumberTypeInt8:
      val = static_cast<T>(reinterpret_cast<int8_t *>(scalar_val)[0]);
      break;
    case kNumberTypeInt16:
      val = static_cast<T>(reinterpret_cast<int16_t *>(scalar_val)[0]);
      break;
    case kNumberTypeInt32:
      val = static_cast<T>(reinterpret_cast<int32_t *>(scalar_val)[0]);
      break;
    case kNumberTypeInt64:
      val = static_cast<T>(reinterpret_cast<int64_t *>(scalar_val)[0]);
      break;
    case kNumberTypeFloat16:
      val = static_cast<T>(reinterpret_cast<float16 *>(scalar_val)[0]);
      break;
    case kNumberTypeFloat32:
      val = static_cast<T>(reinterpret_cast<float *>(scalar_val)[0]);
      break;
    case kNumberTypeFloat64:
      val = static_cast<T>(reinterpret_cast<double *>(scalar_val)[0]);
      break;
    case kNumberTypeBool:
      val = static_cast<T>(reinterpret_cast<bool *>(scalar_val)[0]);
      break;
    case kNumberTypeComplex64:
      val = static_cast<T>(reinterpret_cast<std::complex<float> *>(scalar_val)[0].real());
      break;
    case kNumberTypeComplex128:
      val = static_cast<T>(reinterpret_cast<std::complex<double> *>(scalar_val)[0].real());
      break;
    default:
      MS_EXCEPTION(TypeError) << "For Sspaddmm, the scalar dtype is not supported.";
      break;
  }
  auto task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) {
      sparse_val_bak[i] = sparse_val[i] * val;
    }
  };
  ParallelLaunchAutoSearch(task, data_num, this, &parallel_search_info_);
  return sparse_val_bak;
}

// sparse + sparse: copy the already scaled input entries into the head of the
// output; SparseMulDense then appends the mat1 @ mat2 entries after them
template <typename T, typename S>
void SspaddmmCPUKernelMod::SparseAddSparse(void *input_indices, S *input_values, size_t input_num, int64_t *y_indices,
                                           S *y_values, size_t y_num) {
  auto input_ids = reinterpret_cast<T *>(input_indices);
  this->cnt_ = input_num;
  // write each input entry's value and (row, col) index into the output
  auto task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) {
      auto row = input_ids[i];
      auto col = input_ids[i + input_num];
      y_values[i] = input_values[i];
      y_indices[i] = static_cast<int64_t>(row);
      y_indices[i + y_num] = static_cast<int64_t>(col);
    }
  };
  ParallelLaunchAutoSearch(task, input_num, this, &parallel_search_info_);
}

template <typename T, typename S>
void SspaddmmCPUKernelMod::SparseMulDense(void *mat1_indices, S *mat1_values, size_t mat1_vals_num, S *mat2_addr,
                                          int64_t *y_indices, S *y_values, size_t y_vals_num, int64_t row,
                                          int64_t mat2_col) {
  // the result of mat1 @ mat2 is written into the output directly
  auto mat1_ids = reinterpret_cast<T *>(mat1_indices);

  std::unordered_map<T, std::unordered_map<int64_t, uint32_t>> idx_map_cnt;
  std::unordered_map<T, std::vector<T>> unrepeated;
  std::unordered_map<T, std::unordered_map<T, std::vector<S>>> co_map_idx;

  // unrepeated maps a mat1 row to the cols that hold values, e.g. [1 -> [0], 2 -> [1, 2]];
  // co_map_idx maps a (row, col) pair to its stacked values, e.g. [1][0] -> 0.3
  for (size_t i = 0; i < mat1_vals_num; i++) {
    T _row = mat1_ids[i];
    T _col = mat1_ids[i + mat1_vals_num];
    unrepeated[_row].push_back(_col);
    co_map_idx[_row][_col].push_back(mat1_values[i]);
    for (int64_t j = 0; j < mat2_col; j++) {
      if (idx_map_cnt[_row][j] == 0) {
        idx_map_cnt[_row][j] = this->cnt_;
        this->cnt_++;
      }
    }
  }

  std::vector<T> res;
  for (auto it = unrepeated.begin(); it != unrepeated.end(); it++) {
    res.push_back(it->first);
  }

  size_t n_unrepeat = unrepeated.size();
  auto task = [&](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) {
      auto row_mat1 = res[i];
      for (auto row_mat2 : unrepeated[row_mat1]) {
        S val = co_map_idx[row_mat1][row_mat2].back();
        co_map_idx[row_mat1][row_mat2].pop_back();
        for (int64_t j = 0; j < mat2_col; j++) {
          // accumulate val * mat2[row_mat2][j] into the output slot for (row_mat1, j)
          T idx = idx_map_cnt[row_mat1][j];
          *(y_values + idx) += val * mat2_addr[row_mat2 * mat2_col + j];
          y_indices[idx] = static_cast<int64_t>(row_mat1);
          y_indices[idx + y_vals_num] = j;
        }
      }
    }
  };
  ParallelLaunchAutoSearch(task, n_unrepeat, this, &parallel_search_info_);
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Sspaddmm, SspaddmmCPUKernelMod);
}  // namespace kernel
}  // namespace mindspore
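The kernel above computes y = beta * x1 + alpha * (x2 @ x3), with x1 and x2 given in COO form (indices, values, shape). A minimal NumPy sketch of the same semantics via dense intermediates, for orientation only; it is not the shipped kernel, and the helper name is hypothetical:

# Illustrative reference for the Sspaddmm semantics (dense intermediates).
import numpy as np

def sspaddmm_dense(x1_indices, x1_values, x1_shape, x2_indices, x2_values,
                   x2_shape, x3_dense, alpha, beta):
    """Densify x1 and x2, then compute beta * x1 + alpha * (x2 @ x3)."""
    x1 = np.zeros(x1_shape, dtype=x3_dense.dtype)
    x1[x1_indices[0], x1_indices[1]] = x1_values
    x2 = np.zeros(x2_shape, dtype=x3_dense.dtype)
    x2[x2_indices[0], x2_indices[1]] = x2_values
    return beta * x1 + alpha * (x2 @ x3_dense)

x1_idx = np.array([[0, 1], [0, 1]])   # (2, n) row/col indices
x1_val = np.array([1.0, 2.0])
x2_idx = np.array([[0, 1], [1, 0]])
x2_val = np.array([3.0, 4.0])
x3 = np.arange(6.0).reshape(2, 3)     # (s2, s1) dense factor
print(sspaddmm_dense(x1_idx, x1_val, (2, 3), x2_idx, x2_val, (2, 2),
                     x3, alpha=1.0, beta=1.0))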
@ -0,0 +1,72 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SSPADDMM_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SSPADDMM_CPU_KERNEL_H_

#include <unordered_map>
#include <vector>
#include <string>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class SspaddmmCPUKernelMod : public DeprecatedNativeCpuKernelMod {
 public:
  SspaddmmCPUKernelMod() = default;
  ~SspaddmmCPUKernelMod() override = default;
  void InitKernel(const CNodePtr &kernel_node) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 private:
  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

  void CheckParam(const CNodePtr &kernel_node);
  template <typename T, typename S>
  void CheckSparseIndicesLegal(void *indices_addr, void *shape_addr, size_t num, std::string x_name);
  template <typename T>
  void InitShape(void *input_shape, int64_t *y_shape);
  template <typename T>
  void ClearSparseValues(T *sparse_val, size_t data_num);
  template <typename T>
  T *ScalarSparseMul(T *sparse_val, void *scalar_val, size_t data_num, TypeId tid);
  template <typename T, typename S>
  void SparseAddSparse(void *input_indices, S *input_values, size_t input_num, int64_t *y_indices, S *y_values,
                       size_t y_num);
  template <typename T, typename S>
  void SparseMulDense(void *mat1_indices, S *mat1_values, size_t mat1_vals_num, S *mat2_values_tensor,
                      int64_t *y_indices, S *y_values, size_t y_vals_num, int64_t row, int64_t mat2_col);

  TypeId output_values_dtype_{kTypeUnknown};
  TypeId input_indices_dtype_{kTypeUnknown};
  TypeId input_shape_dtype_{kTypeUnknown};
  TypeId mat1_indices_dtype_{kTypeUnknown};
  TypeId mat1_shape_dtype_{kTypeUnknown};
  TypeId alpha_dtype_{kTypeUnknown};
  TypeId beta_dtype_{kTypeUnknown};
  size_t input_values_num_{0};
  size_t mat1_values_num_{0};
  size_t y_values_num_{0};
  size_t cnt_{0};
  size_t mat2_row_{0};
  size_t mat2_col_{0};
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SSPADDMM_CPU_KERNEL_H_
@ -91,6 +91,7 @@ PrimShapeDependMap &GetHostDependsMap() {
  static const auto &kSegmentSum = prim::kPrimSegmentSum->name();
  static const auto &kBlackmanWindow = prim::kPrimBlackmanWindow->name();
  static const auto &kExpand = prim::kPrimExpand->name();
  static const auto &kSspaddmm = prim::kPrimSspaddmm->name();
  // Common host depends.
  static PrimShapeDependMap host_depends{{kSegmentMax, ShapeSet{1}},
                                         {kSegmentMin, ShapeSet{1}},

@ -132,7 +133,8 @@ PrimShapeDependMap &GetHostDependsMap() {
                                         {kSparseSegmentMean, ShapeSet{2}},
                                         {kResizeLinear1D, ShapeSet{1}},
                                         {kBlackmanWindow, ShapeSet{0}},
                                         {kExpand, ShapeSet{1}}};
                                         {kExpand, ShapeSet{1}},
                                         {kSspaddmm, ShapeSet{0, 2, 3, 5, 7}}};
  return host_depends;
}
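The new ShapeSet{0, 2, 3, 5, 7} entry marks the Sspaddmm inputs whose values shape inference must read on the host. Mapped to the input names used by the op registration later in this commit, a small sketch for orientation (illustrative only, not code from the commit):

# Host-depend inputs for Sspaddmm, keyed by input index; their values feed
# SspaddmmInferShape (names taken from the AiCPU/CPU registration below).
HOST_DEPENDS = {0: "x1_indices", 2: "x1_shape", 3: "x2_indices",
                5: "x2_shape", 7: "alpha"}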
@ -340,6 +340,7 @@ GVAR_DEF(PrimitivePtr, kPrimGatherNd, std::make_shared<Primitive>("GatherNd"));
GVAR_DEF(PrimitivePtr, kPrimSparseGatherV2, std::make_shared<Primitive>("SparseGatherV2"));
GVAR_DEF(PrimitivePtr, kPrimCoalesce, std::make_shared<Primitive>(kCoalesce));
GVAR_DEF(PrimitivePtr, kPrimSparseToDense, std::make_shared<Primitive>("SparseToDense"));
GVAR_DEF(PrimitivePtr, kPrimSspaddmm, std::make_shared<Primitive>("Sspaddmm"));
GVAR_DEF(PrimitivePtr, kPrimShape, std::make_shared<Primitive>("Shape"));
GVAR_DEF(PrimitivePtr, kPrimStridedSlice, std::make_shared<Primitive>(kStridedSlice));
GVAR_DEF(PrimitivePtr, kPrimStridedSliceGrad, std::make_shared<Primitive>(kStridedSliceGrad));
@ -0,0 +1,484 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <set>
#include <vector>
#include <memory>
#include <complex>
#include <map>
#include <string>
#include <climits>
#include "ops/sspaddmm.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"

namespace mindspore {
namespace ops {
namespace {
const int64_t MAX_LEN = 1000000;

// Count the distinct row indices in the first half of a (2, n) indices buffer.
template <typename T>
int64_t ComputeOutputIndicesUniqueSize(T *indices, size_t size) {
  std::set<T> mat1_indices_set;
  size_t half_size = size / 2;
  for (size_t i = 0; i < half_size; i++) {
    (void)mat1_indices_set.insert(indices[i]);
  }
  return static_cast<int64_t>(mat1_indices_set.size());
}

enum DimNum : size_t {
  dim0Num = 0,
  dim1Num,
  dim2Num,
};

int64_t GetInt64AlphaDataOther(void *values, TypeId tid, TypePtr expect_dtype, float real) {
  int64_t compute_val = 0;
  bool flag1 = (expect_dtype->type_id() == kNumberTypeUInt8);
  bool flag2 = false;
  switch (tid) {
    case kNumberTypeFloat16:
      flag2 = flag1 && (reinterpret_cast<float16 *>(values)[0] < static_cast<float16>(0));
      compute_val = static_cast<int64_t>(reinterpret_cast<float16 *>(values)[0]);
      break;
    case kNumberTypeFloat32:
      flag2 = flag1 && (reinterpret_cast<float *>(values)[0] < 0);
      compute_val = static_cast<int64_t>(reinterpret_cast<float *>(values)[0]);
      break;
    case kNumberTypeFloat64:
      flag2 = flag1 && (reinterpret_cast<double *>(values)[0] < 0);
      compute_val = static_cast<int64_t>(reinterpret_cast<double *>(values)[0]);
      break;
    case kNumberTypeBool:
      compute_val = static_cast<int64_t>(reinterpret_cast<bool *>(values)[0]);
      break;
    case kNumberTypeComplex64:
    case kNumberTypeComplex128:
      compute_val = static_cast<int64_t>(real);
      break;
    default:
      MS_EXCEPTION(TypeError) << "For Sspaddmm, the alpha dtype is not supported; only number types, bool,"
                              << " complex64 and complex128 are supported.";
      break;
  }
  if (flag2) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm, the alpha value cannot be converted to type uint8 without overflow.";
  }
  return compute_val;
}

int64_t GetInt64AlphaData(void *values, TypeId tid, TypePtr expect_dtype, float real) {
  int64_t compute_val = 0;
  switch (tid) {
    case kNumberTypeUInt8:
      compute_val = static_cast<int64_t>(reinterpret_cast<uint8_t *>(values)[0]);
      break;
    case kNumberTypeUInt16:
      compute_val = static_cast<int64_t>(reinterpret_cast<uint16_t *>(values)[0]);
      break;
    case kNumberTypeUInt32:
      compute_val = static_cast<int64_t>(reinterpret_cast<uint32_t *>(values)[0]);
      break;
    case kNumberTypeUInt64:
      compute_val = static_cast<int64_t>(reinterpret_cast<uint64_t *>(values)[0]);
      break;
    case kNumberTypeInt8:
      compute_val = static_cast<int64_t>(reinterpret_cast<int8_t *>(values)[0]);
      break;
    case kNumberTypeInt16:
      compute_val = static_cast<int64_t>(reinterpret_cast<int16_t *>(values)[0]);
      break;
    case kNumberTypeInt32:
      compute_val = static_cast<int64_t>(reinterpret_cast<int32_t *>(values)[0]);
      break;
    case kNumberTypeInt64:
      compute_val = static_cast<int64_t>(reinterpret_cast<int64_t *>(values)[0]);
      break;
    default:
      compute_val = GetInt64AlphaDataOther(values, tid, expect_dtype, real);
      break;
  }
  return compute_val;
}

void CheckAlphaBeta(const std::vector<AbstractBasePtr> &input_args) {
  auto alpha_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex7]->BuildShape())[kShape];
  auto beta_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex8]->BuildShape())[kShape];
  if (!(alpha_shape.size() == dim1Num && alpha_shape[0] == dim1Num) && alpha_shape.size() != dim0Num) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm, the shape of alpha should be (1,) or ()"
                             << ", but got dim num " << alpha_shape.size() << " and dim0 size " << alpha_shape[0]
                             << ".";
  }
  if (!(beta_shape.size() == dim1Num && beta_shape[0] == dim1Num) && beta_shape.size() != dim0Num) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm, the shape of beta should be (1,) or ()"
                             << ", but got dim num " << beta_shape.size() << " and dim0 size " << beta_shape[0]
                             << ".";
  }
}

void CheckInputTensor(const std::vector<AbstractBasePtr> &input_args) {
  auto x1_indices_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
  auto x1_values_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
  auto x1_shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  auto x2_indices_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
  auto x2_values_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
  auto x2_shape_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex5]->BuildShape())[kShape];
  auto x3_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->BuildShape())[kShape];

  if (x1_indices_shape.size() != dim2Num) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm, x1_indices should be a 2-D tensor"
                             << ", but got dim num " << x1_indices_shape.size() << ".";
  }
  if (x1_indices_shape[0] != dim2Num) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm, the shape of x1_indices should be (2, n)"
                             << ", but got dim0 size " << x1_indices_shape[0] << ".";
  }
  if (x1_values_shape.size() != dim1Num) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm, x1_values should be a 1-D tensor"
                             << ", but got dim num " << x1_values_shape.size() << ".";
  }
  if (x1_indices_shape[1] != x1_values_shape[0]) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm"
                             << ", the dim1 size of `x1_indices` and the dim0 size of `x1_values` should be the same"
                             << ", but got x1_indices dim1 size " << x1_indices_shape[1]
                             << " and x1_values dim0 size " << x1_values_shape[0] << ".";
  }
  if (x1_shape_shape.size() != dim1Num) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm"
                             << ", x1_shape should be a 1-D tensor, but got dim num " << x1_shape_shape.size() << ".";
  }
  if (x1_shape_shape[0] != dim2Num) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm"
                             << ", the shape of x1_shape should be [2], but got [" << x1_shape_shape[0] << "].";
  }
  if (x2_indices_shape.size() != dim2Num) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm, x2_indices should be a 2-D tensor"
                             << ", but got dim num " << x2_indices_shape.size() << ".";
  }
  if (x2_indices_shape[0] != dim2Num) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm, the shape of x2_indices should be (2, n)"
                             << ", but got dim0 size " << x2_indices_shape[0] << ".";
  }
  if (x2_values_shape.size() != dim1Num) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm, x2_values should be a 1-D tensor"
                             << ", but got dim num " << x2_values_shape.size() << ".";
  }
  if (x2_indices_shape[1] != x2_values_shape[0]) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm"
                             << ", the dim1 size of `x2_indices` and the dim0 size of `x2_values` should be the same"
                             << ", but got x2_indices dim1 size " << x2_indices_shape[1]
                             << " and x2_values dim0 size " << x2_values_shape[0] << ".";
  }
  if (x2_shape_shape.size() != dim1Num) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm"
                             << ", x2_shape should be a 1-D tensor, but got dim num " << x2_shape_shape.size() << ".";
  }
  if (x2_shape_shape[0] != dim2Num) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm"
                             << ", the shape of x2_shape should be [2], but got [" << x2_shape_shape[0] << "].";
  }
  if (x3_shape.size() != dim2Num) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm, x3_dense should be a 2-D tensor"
                             << ", but got dim num " << x3_shape.size() << ".";
  }
  CheckAlphaBeta(input_args);
}

template <typename T>
void IndicesBoundCheck(T *indices_val, size_t indices_num, T *shape_val, std::string name) {
  if (shape_val[0] <= 0 || shape_val[1] <= 0) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm, " << name << "_shape should be positive, "
                             << "but got shape [" << shape_val[0] << ", " << shape_val[1] << "].";
  }
  size_t half_num = indices_num / dim2Num;
  for (size_t i = 0; i < half_num; i++) {
    if (indices_val[i] < 0 || indices_val[i] >= shape_val[0]) {
      MS_EXCEPTION(ValueError) << "For Sspaddmm, " << name << "_indices row index should be in [0, " << shape_val[0]
                               << "), but got row index " << indices_val[i] << ".";
    }
    if (indices_val[i + half_num] < 0 || indices_val[i + half_num] >= shape_val[1]) {
      MS_EXCEPTION(ValueError) << "For Sspaddmm, " << name << "_indices col index should be in [0, " << shape_val[1]
                               << "), but got col index " << indices_val[i + half_num] << ".";
    }
  }
}

void CheckIndices(const std::vector<AbstractBasePtr> &input_args) {
  auto x1_indices_abstract = input_args[kInputIndex0]->cast<abstract::AbstractTensorPtr>();
  MS_EXCEPTION_IF_NULL(x1_indices_abstract);
  auto x1_indices_value_ptr = x1_indices_abstract->BuildValue();
  MS_EXCEPTION_IF_NULL(x1_indices_value_ptr);
  auto x1_indices_tensor = x1_indices_value_ptr->cast<tensor::TensorPtr>();
  MS_EXCEPTION_IF_NULL(x1_indices_tensor);
  auto x1_indices_type = input_args[kInputIndex0]->BuildType();
  MS_EXCEPTION_IF_NULL(x1_indices_type);
  auto x1_indices_type_id = x1_indices_type->cast<TensorTypePtr>();
  MS_EXCEPTION_IF_NULL(x1_indices_type_id);
  auto x1_indices_type_element = x1_indices_type_id->element();
  MS_EXCEPTION_IF_NULL(x1_indices_type_element);
  auto x1_shape_abstract = input_args[kInputIndex2]->cast<abstract::AbstractTensorPtr>();
  MS_EXCEPTION_IF_NULL(x1_shape_abstract);
  auto x1_shape_value_ptr = x1_shape_abstract->BuildValue();
  MS_EXCEPTION_IF_NULL(x1_shape_value_ptr);
  auto x1_shape_tensor = x1_shape_value_ptr->cast<tensor::TensorPtr>();
  MS_EXCEPTION_IF_NULL(x1_shape_tensor);
  if (x1_indices_type_element->type_id() == kNumberTypeInt32) {
    IndicesBoundCheck<int32_t>(reinterpret_cast<int32_t *>(x1_indices_tensor->data_c()), x1_indices_tensor->DataSize(),
                               reinterpret_cast<int32_t *>(x1_shape_tensor->data_c()), "x1");
  } else {
    IndicesBoundCheck<int64_t>(reinterpret_cast<int64_t *>(x1_indices_tensor->data_c()), x1_indices_tensor->DataSize(),
                               reinterpret_cast<int64_t *>(x1_shape_tensor->data_c()), "x1");
  }
  auto x2_indices_abstract = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
  MS_EXCEPTION_IF_NULL(x2_indices_abstract);
  auto x2_indices_value_ptr = x2_indices_abstract->BuildValue();
  MS_EXCEPTION_IF_NULL(x2_indices_value_ptr);
  auto x2_indices_tensor = x2_indices_value_ptr->cast<tensor::TensorPtr>();
  MS_EXCEPTION_IF_NULL(x2_indices_tensor);
  auto x2_indices_type = input_args[kInputIndex3]->BuildType();
  MS_EXCEPTION_IF_NULL(x2_indices_type);
  auto x2_indices_type_id = x2_indices_type->cast<TensorTypePtr>();
  MS_EXCEPTION_IF_NULL(x2_indices_type_id);
  auto x2_indices_type_element = x2_indices_type_id->element();
  MS_EXCEPTION_IF_NULL(x2_indices_type_element);
  auto x2_shape_abstract = input_args[kInputIndex5]->cast<abstract::AbstractTensorPtr>();
  MS_EXCEPTION_IF_NULL(x2_shape_abstract);
  auto x2_shape_value_ptr = x2_shape_abstract->BuildValue();
  MS_EXCEPTION_IF_NULL(x2_shape_value_ptr);
  auto x2_shape_tensor = x2_shape_value_ptr->cast<tensor::TensorPtr>();
  MS_EXCEPTION_IF_NULL(x2_shape_tensor);
  if (x2_indices_type_element->type_id() == kNumberTypeInt32) {
    IndicesBoundCheck<int32_t>(reinterpret_cast<int32_t *>(x2_indices_tensor->data_c()), x2_indices_tensor->DataSize(),
                               reinterpret_cast<int32_t *>(x2_shape_tensor->data_c()), "x2");
  } else {
    IndicesBoundCheck<int64_t>(reinterpret_cast<int64_t *>(x2_indices_tensor->data_c()), x2_indices_tensor->DataSize(),
                               reinterpret_cast<int64_t *>(x2_shape_tensor->data_c()), "x2");
  }
}

bool GetDtypeMinAndMaxAndCheckOverFlow(TypePtr tid, int64_t compute_val) {
  int64_t min = 0;
  int64_t max = 0;
  switch (tid->type_id()) {
    case kNumberTypeUInt8:
      max = UCHAR_MAX;
      min = -UCHAR_MAX - 1;
      break;
    case kNumberTypeInt8:
      max = SCHAR_MAX;
      min = SCHAR_MIN;
      break;
    case kNumberTypeInt16:
      max = SHRT_MAX;
      min = SHRT_MIN;
      break;
    case kNumberTypeInt32:
      max = INT_MAX;
      min = INT_MIN;
      break;
    default:
      max = LONG_MAX;
      min = LONG_MIN;
      break;
  }
  return compute_val < min || compute_val > max;
}

void PrintAlphaValueError(TypeId aid, TypePtr expect_dtype, int64_t compute_val, float real, double imag) {
  if (aid == kNumberTypeComplex64 || aid == kNumberTypeComplex128) {
    MS_EXCEPTION(ValueError) << "For Sspaddmm"
                             << ", alpha cannot be converted to the expected dtype " << expect_dtype->ToString()
                             << " without overflow: (" << real << ", " << imag << ").";
  } else {
    MS_EXCEPTION(ValueError) << "For Sspaddmm"
                             << ", alpha cannot be converted to the expected x2_values dtype "
                             << expect_dtype->ToString() << " without overflow: " << compute_val << ".";
  }
}

abstract::TupleShapePtr SspaddmmInferShape(const PrimitivePtr &primitive,
                                           const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  if (input_args[kInputIndex3]->isa<abstract::AbstractTensor>() &&
      !input_args[kInputIndex3]->BuildValue()->isa<AnyValue>() &&
      !input_args[kInputIndex3]->BuildValue()->isa<None>()) {
    if (input_args[kInputIndex7]->isa<abstract::AbstractTensor>() &&
        !input_args[kInputIndex7]->BuildValue()->isa<AnyValue>() &&
        !input_args[kInputIndex7]->BuildValue()->isa<None>()) {
      CheckIndices(input_args);
      auto alpha_abstract = input_args[kInputIndex7]->cast<abstract::AbstractTensorPtr>();
      auto alpha_value_ptr = alpha_abstract->BuildValue();
      MS_EXCEPTION_IF_NULL(alpha_value_ptr);
      auto alpha_tensor = alpha_value_ptr->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(alpha_tensor);
      auto alpha_dtype = input_args[kInputIndex7]->BuildType();
      MS_EXCEPTION_IF_NULL(alpha_dtype);
      auto alpha_type_id = alpha_dtype->cast<TensorTypePtr>();
      MS_EXCEPTION_IF_NULL(alpha_type_id);
      auto expect_dtype = input_args[kInputIndex1]->BuildType()->cast<TensorTypePtr>()->element();
      auto alpha_type_element = alpha_type_id->element();
      float real = 0;
      double imag = 0;
      if (alpha_type_element->type_id() == kNumberTypeComplex64) {
        auto value = reinterpret_cast<std::complex<float> *>(alpha_tensor->data_c());
        real = value[0].real();
        imag = value[0].imag();
      } else if (alpha_type_element->type_id() == kNumberTypeComplex128) {
        auto value = reinterpret_cast<std::complex<double> *>(alpha_tensor->data_c());
        real = value[0].real();
        imag = value[0].imag();
      }
      if (imag != 0 || (expect_dtype->type_id() == kNumberTypeUInt8 && real < 0)) {
        MS_EXCEPTION(ValueError) << "For " << op_name
                                 << ", the alpha value cannot be converted to type uint8 without overflow: (" << real
                                 << ", " << imag << ").";
      }
      if (!(expect_dtype->type_id() == kNumberTypeFloat32 || expect_dtype->type_id() == kNumberTypeFloat64)) {
        int64_t compute_val =
          GetInt64AlphaData(alpha_tensor->data_c(), alpha_type_element->type_id(), expect_dtype, real);
        if (GetDtypeMinAndMaxAndCheckOverFlow(expect_dtype, compute_val)) {
          PrintAlphaValueError(alpha_type_element->type_id(), expect_dtype, compute_val, real, imag);
        }
      }
    } else {
      MS_EXCEPTION(ValueError) << "For Sspaddmm, the value of alpha cannot be obtained.";
    }
    auto x2_indices_abstract = input_args[kInputIndex3]->cast<abstract::AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(x2_indices_abstract);
    auto x2_indices_value_ptr = x2_indices_abstract->BuildValue();
    MS_EXCEPTION_IF_NULL(x2_indices_value_ptr);
    auto x2_indices_tensor = x2_indices_value_ptr->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(x2_indices_tensor);
    auto x2_indices_type = input_args[kInputIndex3]->BuildType();
    MS_EXCEPTION_IF_NULL(x2_indices_type);
    auto x2_indices_type_id = x2_indices_type->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(x2_indices_type_id);
    auto x2_indices_type_element = x2_indices_type_id->element();
    MS_EXCEPTION_IF_NULL(x2_indices_type_element);
    int64_t x2_indices_unique_size = 0;
    if (x2_indices_type_element->type_id() == kNumberTypeInt32) {
      x2_indices_unique_size = ComputeOutputIndicesUniqueSize<int32_t>(
        reinterpret_cast<int32_t *>(x2_indices_tensor->data_c()), x2_indices_tensor->DataSize());
    } else {
      x2_indices_unique_size = ComputeOutputIndicesUniqueSize<int64_t>(
        reinterpret_cast<int64_t *>(x2_indices_tensor->data_c()), x2_indices_tensor->DataSize());
    }
    auto x1_indices_shape =
      CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
    auto x3_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex6]->BuildShape())[kShape];
    int64_t x2_indices_shape_right = x2_indices_unique_size * x3_shape[1] + x1_indices_shape[1];
    std::vector<int64_t> output_indices_shape = {2, x2_indices_shape_right};
    abstract::ShapePtr output_indices_shape_list =
      std::make_shared<abstract::Shape>(output_indices_shape, output_indices_shape, output_indices_shape);
    std::vector<int64_t> output_values_shape = {x2_indices_shape_right};
    abstract::ShapePtr output_values_shape_list =
      std::make_shared<abstract::Shape>(output_values_shape, output_values_shape, output_values_shape);
    auto input_shape = input_args[kInputIndex2]->BuildShape();
    abstract::ShapePtr output_shape_shape_list = input_shape->cast<abstract::ShapePtr>();
    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
      output_indices_shape_list, output_values_shape_list, output_shape_shape_list});
  } else {
    std::vector<int64_t> output_shape = {abstract::Shape::SHP_ANY};
    std::vector<int64_t> infer_shape_min = {0};
    std::vector<int64_t> infer_shape_max = {MAX_LEN};
    abstract::ShapePtr output_shape_list =
      std::make_shared<abstract::Shape>(output_shape, infer_shape_min, infer_shape_max);
    return std::make_shared<abstract::TupleShape>(
      std::vector<abstract::BaseShapePtr>{output_shape_list, output_shape_list, output_shape_list});
  }
}

TuplePtr SspaddmmInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  auto op_name = prim->name();
  std::map<std::string, TypePtr> x1_args = {{"x1_indices", input_args[kInputIndex0]->BuildType()},
                                            {"x1_shape", input_args[kInputIndex2]->BuildType()}};
  (void)CheckAndConvertUtils::CheckTensorTypeSame(x1_args, {kInt32, kInt64}, op_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x1_values", input_args[kInputIndex1]->BuildType(),
                                                   {kUInt8, kInt8, kInt16, kInt32, kInt64, kFloat32, kFloat64},
                                                   op_name);
  std::map<std::string, TypePtr> x2_args = {{"x2_indices", input_args[kInputIndex3]->BuildType()},
                                            {"x2_shape", input_args[kInputIndex5]->BuildType()}};
  (void)CheckAndConvertUtils::CheckTensorTypeSame(x2_args, {kInt32, kInt64}, op_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x2_values", input_args[kInputIndex4]->BuildType(),
                                                   {kUInt8, kInt8, kInt16, kInt32, kInt64, kFloat32, kFloat64},
                                                   op_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x3_dense", input_args[kInputIndex6]->BuildType(),
                                                   {kUInt8, kInt8, kInt16, kInt32, kInt64, kFloat32, kFloat64},
                                                   op_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid(
    "alpha", input_args[kInputIndex7]->BuildType(),
    {kUInt8, kUInt16, kUInt32, kUInt64, kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32, kFloat64}, op_name);
  (void)CheckAndConvertUtils::CheckTensorTypeValid(
    "beta", input_args[kInputIndex8]->BuildType(),
    {kUInt8, kUInt16, kUInt32, kUInt64, kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32, kFloat64}, op_name);
  auto expect_dtype = input_args[kInputIndex1]->BuildType()->cast<TensorTypePtr>()->element();
  auto beta_dtype = input_args[kInputIndex8]->BuildType()->cast<TensorTypePtr>()->element();
  if (!(expect_dtype->type_id() == kNumberTypeFloat32 || expect_dtype->type_id() == kNumberTypeFloat64)) {
    auto beta_dtype_id = beta_dtype->type_id();
    if (beta_dtype_id == kNumberTypeFloat16 || beta_dtype_id == kNumberTypeFloat32 ||
        beta_dtype_id == kNumberTypeFloat64) {
      MS_EXCEPTION(TypeError) << "For " << op_name << ", beta dtype: " << beta_dtype->ToString()
                              << " cannot be converted to the desired output dtype: " << expect_dtype->ToString()
                              << ".";
    }
  }
  std::map<std::string, TypePtr> args = {{"x1_values", input_args[kInputIndex1]->BuildType()},
                                         {"x2_values", input_args[kInputIndex4]->BuildType()},
                                         {"x3_dense", input_args[kInputIndex6]->BuildType()}};
  auto output_values_type = CheckAndConvertUtils::CheckTensorTypeSame(
    args, {kInt8, kInt16, kInt32, kInt64, kUInt8, kFloat32, kFloat64}, op_name);
  return std::make_shared<Tuple>(std::vector<TypePtr>{kInt64, output_values_type, kInt64});
}
}  // namespace

AbstractBasePtr SspaddmmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t input_num = 9;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
  CheckInputTensor(input_args);
  auto infer_type = SspaddmmInferType(primitive, input_args);
  auto infer_shape = SspaddmmInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}

REGISTER_PRIMITIVE_EVAL_IMPL(Sspaddmm, prim::kPrimSspaddmm, SspaddmmInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
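The key sizing step in SspaddmmInferShape above is `x2_indices_unique_size * x3_shape[1] + x1_indices_shape[1]`: every distinct row of x2 produces one full dense row of x2 @ x3 (x3_shape[1] entries), and the x1 entries are appended unmerged. A quick sketch of that count, using a hypothetical helper name:

# Output nnz used by shape inference: each distinct x2 row yields x3_cols
# entries, plus all x1 entries appended as-is.
def output_nnz(x2_row_indices, x3_cols, x1_nnz):
    return len(set(x2_row_indices)) * x3_cols + x1_nnz

assert output_nnz([0, 1, 1], x3_cols=3, x1_nnz=2) == 8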
@ -0,0 +1,49 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_SSPADDMM_H_
#define MINDSPORE_CORE_OPS_SSPADDMM_H_
#include <vector>
#include <set>
#include <memory>

#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameSspaddmm = "Sspaddmm";
/// \brief Multiplies the sparse matrix x2 by the dense matrix x3, then adds the sparse matrix x1 to the product.
/// Refer to Python API @ref mindspore.ops.Sspaddmm for more details.
class MIND_API Sspaddmm : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(Sspaddmm);
  /// \brief Constructor.
  Sspaddmm() : BaseOperator(kNameSspaddmm) {
    InitIOName(
      {"x1_indices", "x1_values", "x1_shape", "x2_indices", "x2_values", "x2_shape", "x3_dense", "alpha", "beta"},
      {"y_indices", "y_values", "y_shape"});
  }
  /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Sspaddmm for the inputs.
  void Init() {}
};
abstract::AbstractBasePtr SspaddmmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const std::vector<abstract::AbstractBasePtr> &input_args);
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_SSPADDMM_H_
@ -209,3 +209,4 @@ from .scatter_nd_max import _scatter_nd_max_aicpu
from .conj import _conj_aicpu
from .scatter_nd_min import _scatter_nd_min_aicpu
from .cholesky import _cholesky_aicpu
from .sspaddmm import _sspaddmm_aicpu
@ -0,0 +1,97 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Sspaddmm op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType


def _reg_aicpu():
    """base op info"""
    return AiCPURegOp("Sspaddmm") \
        .fusion_type("OPAQUE") \
        .input(0, "x1_indices", "required") \
        .input(1, "x1_values", "required") \
        .input(2, "x1_shape", "required") \
        .input(3, "x2_indices", "required") \
        .input(4, "x2_values", "required") \
        .input(5, "x2_shape", "required") \
        .input(6, "x3_dense", "required") \
        .input(7, "alpha", "required") \
        .input(8, "beta", "required") \
        .output(0, "y_indices", "required") \
        .output(1, "y_values", "required") \
        .output(2, "y_shape", "required")


def _reg_format(op_info, dtype1, dtype2, dtype3, alpha, beta):
    """single dtype_format row"""
    return op_info.dtype_format(dtype1, dtype3, dtype1, dtype2, dtype3, dtype2, dtype3, alpha, beta,
                                DataType.I64_Default, dtype3, DataType.I64_Default)


def _reg_format_beta(op_info, dtype1, dtype2, alpha, beta):
    """beta reg"""
    op_info = _reg_format(op_info, dtype1, dtype2, DataType.U8_Default, alpha, beta)
    op_info = _reg_format(op_info, dtype1, dtype2, DataType.I8_Default, alpha, beta)
    op_info = _reg_format(op_info, dtype1, dtype2, DataType.I16_Default, alpha, beta)
    op_info = _reg_format(op_info, dtype1, dtype2, DataType.I32_Default, alpha, beta)
    op_info = _reg_format(op_info, dtype1, dtype2, DataType.I64_Default, alpha, beta)
    op_info = _reg_format(op_info, dtype1, dtype2, DataType.F32_Default, alpha, beta)
    op_info = _reg_format(op_info, dtype1, dtype2, DataType.F64_Default, alpha, beta)
    return op_info


def _reg_format_alpha(op_info, dtype1, dtype2, alpha):
    """alpha reg"""
    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.U8_Default)
    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.U16_Default)
    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.U32_Default)
    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.U64_Default)
    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.I8_Default)
    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.I16_Default)
    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.I32_Default)
    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.I64_Default)
    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.F16_Default)
    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.F32_Default)
    op_info = _reg_format_beta(op_info, dtype1, dtype2, alpha, DataType.F64_Default)
    return op_info


def _reg_format_indices(op_info, dtype1, dtype2):
    """indices reg"""
    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.U8_Default)
    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.U16_Default)
    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.U32_Default)
    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.U64_Default)
    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.I8_Default)
    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.I16_Default)
    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.I32_Default)
    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.I64_Default)
    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.F16_Default)
    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.F32_Default)
    op_info = _reg_format_alpha(op_info, dtype1, dtype2, DataType.F64_Default)
    return op_info


def _reg_format_indices_shape(op_info):
    """shape reg"""
    op_info = _reg_format_indices(op_info, DataType.I32_Default, DataType.I32_Default)
    return op_info.get_op_info()


sspaddmm_op_info = _reg_format_indices_shape(_reg_aicpu())


@op_info_register(sspaddmm_op_info)
def _sspaddmm_aicpu():
    """Sspaddmm AiCPU register"""
    return
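The nested _reg_format_* helpers above enumerate every supported (alpha, beta, values) dtype combination with the indices dtype fixed to int32. A back-of-envelope count of the dtype_format rows they emit (illustrative only, not part of the registration):

# 11 alpha dtypes x 11 beta dtypes x 7 value dtypes, indices fixed to int32.
ALPHA_DTYPES = BETA_DTYPES = 11  # U8..U64, I8..I64, F16, F32, F64
VALUE_DTYPES = 7                 # U8, I8, I16, I32, I64, F32, F64
print(ALPHA_DTYPES * BETA_DTYPES * VALUE_DTYPES)  # 847 registered formats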
@ -78,3 +78,4 @@ from .buffer_sample import _buffer_sample_cpu
from .priority_replay_buffer import _prb_push_op_cpu
from .priority_replay_buffer import _prb_sample_op_cpu
from .space_to_batch_nd import _space_to_batch_nd_cpu
from .sspaddmm import _sspaddmm_cpu
@@ -0,0 +1,95 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Sspaddmm op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType

sspaddmm_op_info = CpuRegOp("Sspaddmm") \
    .input(0, "x1_indices", "required") \
    .input(1, "x1_values", "required") \
    .input(2, "x1_shape", "required") \
    .input(3, "x2_indices", "required") \
    .input(4, "x2_values", "required") \
    .input(5, "x2_shape", "required") \
    .input(6, "x3_dense", "required") \
    .input(7, "alpha", "required") \
    .input(8, "beta", "required") \
    .output(0, "y_indices", "required") \
    .output(1, "y_values", "required") \
    .output(2, "y_shape", "required")


def _reg_format(op_info, dtype, alpha, beta):
    """Register one dtype_format entry: indices/shapes are int32, outputs are int64."""
    return op_info.dtype_format(DataType.I32_Default, dtype, DataType.I32_Default, DataType.I32_Default, dtype,
                                DataType.I32_Default, dtype, alpha, beta, DataType.I64_Default, dtype,
                                DataType.I64_Default)


def _reg_format_beta(op_info, alpha, beta):
    """Register all supported value dtypes for a fixed (alpha, beta) dtype pair."""
    op_info = _reg_format(op_info, DataType.U8_Default, alpha, beta)
    op_info = _reg_format(op_info, DataType.I8_Default, alpha, beta)
    op_info = _reg_format(op_info, DataType.I16_Default, alpha, beta)
    op_info = _reg_format(op_info, DataType.I32_Default, alpha, beta)
    op_info = _reg_format(op_info, DataType.I64_Default, alpha, beta)
    op_info = _reg_format(op_info, DataType.F32_Default, alpha, beta)
    op_info = _reg_format(op_info, DataType.F64_Default, alpha, beta)
    return op_info


def _reg_format_alpha(op_info, alpha):
    """Register all supported beta dtypes for a fixed alpha dtype."""
    op_info = _reg_format_beta(op_info, alpha, DataType.U8_Default)
    op_info = _reg_format_beta(op_info, alpha, DataType.U16_Default)
    op_info = _reg_format_beta(op_info, alpha, DataType.U32_Default)
    op_info = _reg_format_beta(op_info, alpha, DataType.U64_Default)
    op_info = _reg_format_beta(op_info, alpha, DataType.I8_Default)
    op_info = _reg_format_beta(op_info, alpha, DataType.I16_Default)
    op_info = _reg_format_beta(op_info, alpha, DataType.I32_Default)
    op_info = _reg_format_beta(op_info, alpha, DataType.I64_Default)
    op_info = _reg_format_beta(op_info, alpha, DataType.F16_Default)
    op_info = _reg_format_beta(op_info, alpha, DataType.F32_Default)
    op_info = _reg_format_beta(op_info, alpha, DataType.F64_Default)
    return op_info


def _reg_format_indices(op_info):
    """Register all supported alpha dtypes."""
    op_info = _reg_format_alpha(op_info, DataType.U8_Default)
    op_info = _reg_format_alpha(op_info, DataType.U16_Default)
    op_info = _reg_format_alpha(op_info, DataType.U32_Default)
    op_info = _reg_format_alpha(op_info, DataType.U64_Default)
    op_info = _reg_format_alpha(op_info, DataType.I8_Default)
    op_info = _reg_format_alpha(op_info, DataType.I16_Default)
    op_info = _reg_format_alpha(op_info, DataType.I32_Default)
    op_info = _reg_format_alpha(op_info, DataType.I64_Default)
    op_info = _reg_format_alpha(op_info, DataType.F16_Default)
    op_info = _reg_format_alpha(op_info, DataType.F32_Default)
    op_info = _reg_format_alpha(op_info, DataType.F64_Default)
    return op_info


def _reg_format_indices_shape(op_info):
    """Finish registration and return the assembled op info."""
    op_info = _reg_format_indices(op_info)
    return op_info.get_op_info()


sspaddmm_op_info_all = _reg_format_indices_shape(sspaddmm_op_info)


@op_info_register(sspaddmm_op_info_all)
def _sspaddmm_cpu():
    """Sspaddmm cpu register"""
    return
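
For reviewers, the nested helpers above expand combinatorially: 11 alpha dtypes x 11 beta dtypes x 7 value dtypes, so the single `get_op_info()` call registers 11 * 11 * 7 = 847 `dtype_format` entries. An equivalent flat loop, shown only to make the expansion explicit (not a proposed rewrite of the file):

# Illustrative only: a flat enumeration equivalent to
# _reg_format_indices / _reg_format_alpha / _reg_format_beta above.
SCALAR_DTYPES = (DataType.U8_Default, DataType.U16_Default, DataType.U32_Default, DataType.U64_Default,
                 DataType.I8_Default, DataType.I16_Default, DataType.I32_Default, DataType.I64_Default,
                 DataType.F16_Default, DataType.F32_Default, DataType.F64_Default)  # alpha/beta: 11 dtypes
VALUE_DTYPES = (DataType.U8_Default, DataType.I8_Default, DataType.I16_Default, DataType.I32_Default,
                DataType.I64_Default, DataType.F32_Default, DataType.F64_Default)   # values: 7 dtypes


def _enumerate_formats(op_info):
    """Register every (alpha, beta, value-dtype) combination: 11 * 11 * 7 = 847 entries."""
    for alpha in SCALAR_DTYPES:
        for beta in SCALAR_DTYPES:
            for dtype in VALUE_DTYPES:
                op_info = _reg_format(op_info, dtype, alpha, beta)
    return op_info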
@@ -379,3 +379,103 @@ class DenseToDenseSetOperation(Primitive):
        self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y_indices', 'y_values', 'y_shape'])
        validator.check_value_type("set_operation", set_operation, [str], self.name)
        validator.check_value_type("validate_indices", validate_indices, [bool], self.name)


class Sspaddmm(Primitive):
    r"""
    Matrix multiplies a sparse tensor `x2` with a dense tensor `x3_dense`, then adds the sparse tensor `x1`.
    If `x1_shape` is :math:`(s0, s1)`, then `x2_shape` should be :math:`(s0, s2)` and the shape of `x3_dense`
    should be :math:`(s2, s1)`.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    .. math::
        out = \beta * x1 + \alpha * (x2 @ x3)

    Inputs:
        - **x1_indices** (Tensor) - A 2-D Tensor, represents the position of each element in the sparse tensor.
          Support int32, int64. The shape is :math:`(2, n)`. If `x1_shape` is :math:`(s0, s1)`, each row index in
          `x1_indices` should be a non-negative integer less than `s0`, and each column index should be a
          non-negative integer less than `s1`.
        - **x1_values** (Tensor) - A 1-D Tensor, represents the value corresponding to each position in
          `x1_indices`. Support float32, float64, int8, int16, int32, int64, uint8. The dtype should be the same
          as that of `x2_values` and `x3_dense`. The shape should be :math:`(n,)`.
        - **x1_shape** (Tensor) - A 1-D Tensor, specifies the shape of the sparse tensor. Support int32, int64.
          It has 2 positive int elements, and its shape is :math:`(2,)`. The dtype should be the same as
          `x1_indices`.
        - **x2_indices** (Tensor) - A 2-D Tensor, represents the position of each element in the sparse tensor.
          Support int32, int64. The shape is :math:`(2, n)`. If `x2_shape` is :math:`(s0, s2)`, each row index in
          `x2_indices` should be a non-negative integer less than `s0`, and each column index should be a
          non-negative integer less than `s2`.
        - **x2_values** (Tensor) - A 1-D Tensor, represents the value corresponding to each position in
          `x2_indices`. Support float32, float64, int8, int16, int32, int64, uint8. The dtype should be the same
          as that of `x1_values` and `x3_dense`. The shape should be :math:`(n,)`.
        - **x2_shape** (Tensor) - A 1-D Tensor, specifies the shape of the sparse tensor. Support int32, int64.
          It has 2 positive int elements, and its shape is :math:`(2,)`. The dtype is the same as `x2_indices`.
        - **x3_dense** (Tensor) - A 2-D Tensor, the dtype should be the same as `x1_values` and `x2_values`.
        - **alpha** (Tensor) - A 0-D or 1-D Tensor, the weight of :math:`x2 @ x3`. If `alpha` is a 0-D Tensor,
          the shape is :math:`()`; otherwise the shape is :math:`(1,)`. Support uint8, uint16, uint32, uint64,
          int8, int16, int32, int64, float16, float32, float64. If the dtype of `alpha` is not the same as the
          expected output dtype, the value of `alpha` must convert to the output dtype without overflow.
        - **beta** (Tensor) - A 0-D or 1-D Tensor, the weight of `x1`. If `beta` is a 0-D Tensor, the shape is
          :math:`()`; otherwise the shape is :math:`(1,)`. Support uint8, uint16, uint32, uint64, int8, int16,
          int32, int64, float16, float32, float64. If the dtype of `x1_values` is uint8, int8, int16, int32 or
          int64, the dtype of `beta` does not support float16, float32 or float64.

    Outputs:
        - **y_indices** (Tensor) - A 2-D Tensor, represents the position of each element in the sparse tensor.
          The dtype is int64, and each element is a non-negative integer. The shape is :math:`(2, n)`.
        - **y_values** (Tensor) - A 1-D Tensor, represents the value corresponding to each position in
          `y_indices`. The dtype is the same as `x1_values`. The shape should be :math:`(n,)`.
        - **y_shape** (Tensor) - A 1-D Tensor, specifies the shape of the sparse tensor. The dtype is int64, and
          the values are the same as `x1_shape`.

    Raises:
        TypeError: If the dtypes of `x1_indices` and `x1_shape` are not the same, or are neither int32 nor int64.
        TypeError: If the dtypes of `x2_indices` and `x2_shape` are not the same, or are neither int32 nor int64.
        TypeError: If the dtypes of `x1_values`, `x2_values` and `x3_dense` are not the same.
        TypeError: If the dtype of `x1_values`, `x2_values` or `x3_dense` is not uint8, int8, int16, int32,
            int64, float32 or float64.
        ValueError: If the shape of `x1_indices` or `x2_indices` is not :math:`(2, n)`.
        ValueError: If the shape of `x1_values` or `x2_values` is not :math:`(n,)`.
        ValueError: If the dim0 size of `x1_values` is not the same as the dim1 size of `x1_indices`.
        ValueError: If the dim0 size of `x2_values` is not the same as the dim1 size of `x2_indices`.
        ValueError: If the shape of `x1_shape` or `x2_shape` is not :math:`(2,)`.
        ValueError: If `x3_dense` is not a 2-D Tensor.
        ValueError: If the dtype of `alpha` is not the same as the `x2_values` dtype and converting the `alpha`
            value to that dtype overflows.
        TypeError: If the dtype of `alpha` or `beta` is not uint8, uint16, uint32, uint64, int8, int16, int32,
            int64, float16, float32 or float64.
        TypeError: If the dtype of `x1_values` is uint8, int8, int16, int32 or int64 while the dtype of `beta`
            is float16, float32 or float64.
        ValueError: If the shape of `alpha` or `beta` is not :math:`()` or :math:`(1,)`.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> x1_indices = Tensor(np.array([[0, 1], [0, 1]]), mstype.int64)
        >>> x1_values = Tensor(np.array([1, 2]), mstype.int32)
        >>> x1_shape = Tensor(np.array([3, 3]), mstype.int64)
        >>> x2_indices = Tensor(np.array([[0, 1], [2, 2]]), mstype.int64)
        >>> x2_values = Tensor(np.array([3, 4]), mstype.int32)
        >>> x2_shape = Tensor(np.array([3, 3]), mstype.int64)
        >>> x3_dense = Tensor(np.array([[1, 2, 3], [1, 3, 2], [3, 2, 1]]), mstype.int32)
        >>> alpha = Tensor(np.array(1), mstype.int32)
        >>> beta = Tensor(np.array(1), mstype.int32)
        >>> sspaddmm = ops.Sspaddmm()
        >>> out_indices, out_values, out_shapes = sspaddmm(x1_indices, x1_values, x1_shape,
        ...                                                x2_indices, x2_values, x2_shape, x3_dense, alpha, beta)
        >>> print(out_indices)
        [[0 1 0 0 0 1 1 1]
         [0 1 0 1 2 0 1 2]]
        >>> print(out_values)
        [ 1  2  9  6  3 12  8  4]
        >>> print(out_shapes)
        [3 3]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Sspaddmm."""
        self.init_prim_io_names(inputs=['x1_indices', 'x1_values', 'x1_shape', 'x2_indices', 'x2_values',
                                        'x2_shape', 'x3_dense', 'alpha', 'beta'],
                                outputs=['y_indices', 'y_values', 'y_shape'])
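
A quick way to sanity-check the docstring example against the formula is to compute the dense equivalent with plain NumPy (a reviewer-side verification sketch, not part of the operator):

import numpy as np

# Dense equivalent of the example: out = beta * x1 + alpha * (x2 @ x3), with alpha = beta = 1.
x1 = np.zeros((3, 3), dtype=np.int32)
x1[[0, 1], [0, 1]] = [1, 2]            # x1_indices / x1_values, densified
x2 = np.zeros((3, 3), dtype=np.int32)
x2[[0, 1], [2, 2]] = [3, 4]            # x2_indices / x2_values, densified
x3 = np.array([[1, 2, 3], [1, 3, 2], [3, 2, 1]], dtype=np.int32)
out = 1 * x1 + 1 * (x2 @ x3)
print(out)
# [[10  6  3]
#  [12 10  4]
#  [ 0  0  0]]

This matches the sparse output above once duplicates are summed: the operator keeps the `x1` entries and the `x2 @ x3` entries as separate, uncoalesced coordinates, which is why `(0, 0)` appears twice in `out_indices` (values 1 and 9, summing to 10; likewise `(1, 1)` holds 2 and 8).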
@@ -225,6 +225,11 @@ class InputOpNet(nn.Cell):
        return x


    def construct9_c0(self, x1, x2, x3, x4, x5, x6, x7, x8, x9):
        x = self.op(x1, x2, x3, x4, x5, x6, x7, x8, x9)
        return x


def gen_net(op, input_num, training=True, desc_const=(), const_first=False, add_fake_input=False):
    if isinstance(op, nn.Cell):
        return op
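
The test harness pairs each op with a `construct{N}_c{M}` method according to its tensor-input and constant counts; Sspaddmm, with nine tensor inputs and no constants, is why `construct9_c0` is added here. A hypothetical sketch of that name dispatch (`pick_construct_name` is illustrative only, not the harness's real helper):

# Hypothetical helper, for illustration only: maps (input_num, const_num)
# to the InputOpNet construct method name the harness would invoke.
def pick_construct_name(input_num, const_num):
    return 'construct{}_c{}'.format(input_num, const_num)


assert pick_construct_name(9, 0) == 'construct9_c0'  # Sspaddmm: 9 tensor inputs, 0 constants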
@@ -73,7 +73,7 @@ from mindspore.ops.operations.nn_ops import MaxPoolV1
from mindspore.ops.operations.array_ops import NonZero
from mindspore.ops.operations._grad_ops import MaxPoolGradV1
from mindspore.ops.operations.nn_ops import ReLUV3
from mindspore.ops.operations.sparse_ops import DenseToCSRSparseMatrix
from mindspore.ops.operations.sparse_ops import DenseToCSRSparseMatrix, Sspaddmm
from mindspore.ops.operations.other_ops import BlackmanWindow
from mindspore.nn.layer import normalization
from mindspore.ops.operations.array_ops import RightShift
@@ -1605,6 +1605,18 @@ test_case_math_ops = [
        'desc_inputs': [Tensor(np.array([[1, 0], [0, 1]]).astype(np.float32)),
                        Tensor(np.array([[0, 0], [1, 1]]).astype(np.int32))],
        'skip': ['backward']}),
    ('Sspaddmm', {
        'block': Sspaddmm(),
        'desc_inputs': [Tensor(np.array([[0, 1], [0, 1]]).astype(np.int64)),
                        Tensor(np.array([1, 2]).astype(np.int32)),
                        Tensor(np.array([3, 3]).astype(np.int64)),
                        Tensor(np.array([[0, 1], [2, 2]]).astype(np.int64)),
                        Tensor(np.array([3, 4]).astype(np.int32)),
                        Tensor(np.array([3, 3]).astype(np.int64)),
                        Tensor(np.array([[1, 2, 3], [1, 3, 2], [3, 2, 1]]).astype(np.int32)),
                        Tensor(np.array(1).astype(np.int32)),
                        Tensor(np.array(1).astype(np.int32))],
        'skip': ['backward']}),
    ('Embedding_1', {
        'block': Embedding(vocab_size=10, embedding_size=3),
        'desc_inputs': [Tensor(np.array([0, 2, 2, 7]).astype(np.int32))],