clean code

hujiahui8 2022-08-03 16:18:57 +08:00
parent 2b35007b0c
commit 886610fd80
4 changed files with 116 additions and 92 deletions

@@ -27,15 +27,26 @@ constexpr size_t kInputsNum = 9;
constexpr size_t kOutputsNum = 3;
constexpr char kKernelName[] = "Sspaddmm";
#define CHECKSPARSEINDICES(dtype1, dtype2, indices, shapes, num, x_name) \
if (dtype1 == kNumberTypeInt32) { \
CheckSparseIndicesLegal<int32_t, int32_t>(indices, shapes, num, x_name); \
} else { \
CheckSparseIndicesLegal<int64_t, int64_t>(indices, shapes, num, x_name); \
}
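// named positions of the nine kernel inputs: x1 indices / values / shape, mat1 indices / values / shape,
// mat2, alpha and beta; kOutputShapeIndex addresses the third output (the y shape tensor)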
constexpr size_t kOutputShapeIndex = 2;
constexpr size_t kInputShapeIndex = 2;
constexpr size_t kMat1IndiceIndex = 3;
constexpr size_t kMat1ValueIndex = 4;
constexpr size_t kMat1ShapeIndex = 5;
constexpr size_t kMat2Index = 6;
constexpr size_t kAlphaIndex = 7;
constexpr size_t kBetaIndex = 8;
} // namespace
void SspaddmmCPUKernelMod::CheckParam(const CNodePtr &kernel_node) {
void SspaddmmCPUKernelMod::CheckSparseIndices(const TypeId &indices_dtype, void *indices_addr, void *shape_addr,
size_t num, const std::string &x_name) const {
if (indices_dtype == kNumberTypeInt32) {
CheckSparseIndicesLegal<int32_t, int32_t>(indices_addr, shape_addr, num, x_name);
} else {
CheckSparseIndicesLegal<int64_t, int64_t>(indices_addr, shape_addr, num, x_name);
}
}
void SspaddmmCPUKernelMod::CheckParam(const CNodePtr &kernel_node) const {
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
CHECK_KERNEL_INPUTS_NUM(input_num, kInputsNum, kKernelName);
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
@@ -59,14 +70,14 @@ void SspaddmmCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
auto y_indices_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0);
auto mat2_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex6);
input_values_num_ = input_indices_shape[1];
mat1_values_num_ = mat1_indices_shape[1];
y_values_num_ = y_indices_shape[1];
mat2_row_ = mat2_shape[0];
mat2_col_ = mat2_shape[1];
input_values_num_ = LongToSize(input_indices_shape[1]);
mat1_values_num_ = LongToSize(mat1_indices_shape[1]);
y_values_num_ = LongToSize(y_indices_shape[1]);
mat2_row_ = LongToSize(mat2_shape[0]);
mat2_col_ = LongToSize(mat2_shape[1]);
}
bool SspaddmmCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool SspaddmmCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
switch (output_values_dtype_) {
case kNumberTypeUInt8: {
@@ -107,38 +118,39 @@ bool SspaddmmCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs, const s
template <typename T>
void SspaddmmCPUKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
auto input_indices_addr = inputs[0]->addr;
auto input_values_addr = reinterpret_cast<T *>(inputs[1]->addr);
auto input_shape_addr = inputs[2]->addr;
auto mat1_indices_addr = inputs[3]->addr;
auto mat1_values_addr = reinterpret_cast<T *>(inputs[4]->addr);
auto mat1_shape_addr = inputs[5]->addr;
auto mat2_addr = reinterpret_cast<T *>(inputs[6]->addr);
auto alpha_val_addr = inputs[7]->addr;
auto beta_val_addr = inputs[8]->addr;
auto y_indices_addr = reinterpret_cast<int64_t *>(outputs[0]->addr);
auto y_values_addr = reinterpret_cast<T *>(outputs[1]->addr);
auto y_shape_addr = reinterpret_cast<int64_t *>(outputs[2]->addr);
CHECKSPARSEINDICES(input_indices_dtype_, input_shape_dtype_, input_indices_addr, input_shape_addr, input_values_num_,
"x1");
CHECKSPARSEINDICES(mat1_indices_dtype_, mat1_shape_dtype_, mat1_indices_addr, mat1_shape_addr, mat1_values_num_,
"x2");
auto input_values_addr = GetDeviceAddress<T>(inputs, 1);
auto input_shape_addr = inputs[kInputShapeIndex]->addr;
auto mat1_indices_addr = inputs[kMat1IndiceIndex]->addr;
auto mat1_values_addr = GetDeviceAddress<T>(inputs, kMat1ValueIndex);
auto mat1_shape_addr = inputs[kMat1ShapeIndex]->addr;
auto mat2_addr = GetDeviceAddress<T>(inputs, kMat2Index);
auto alpha_val_addr = inputs[kAlphaIndex]->addr;
auto beta_val_addr = inputs[kBetaIndex]->addr;
auto y_indices_addr = GetDeviceAddress<int64_t>(outputs, 0);
auto y_values_addr = GetDeviceAddress<T>(outputs, 1);
auto y_shape_addr = GetDeviceAddress<int64_t>(outputs, kOutputShapeIndex);
const std::string x1_name = "x1";
const std::string x2_name = "x2";
CheckSparseIndices(input_indices_dtype_, input_indices_addr, input_shape_addr, input_values_num_, x1_name);
CheckSparseIndices(mat1_indices_dtype_, mat1_indices_addr, mat1_shape_addr, mat1_values_num_, x2_name);
int64_t mat1_row, mat1_col, input_row, input_col;
if (mat1_shape_dtype_ == kNumberTypeInt32) {
auto mat1_shape_val = reinterpret_cast<int32_t *>(mat1_shape_addr);
auto mat1_shape_val = static_cast<int32_t *>(mat1_shape_addr);
mat1_row = static_cast<int64_t>(mat1_shape_val[0]);
mat1_col = static_cast<int64_t>(mat1_shape_val[1]);
} else {
auto mat1_shape_val = reinterpret_cast<int64_t *>(mat1_shape_addr);
auto mat1_shape_val = static_cast<int64_t *>(mat1_shape_addr);
mat1_row = mat1_shape_val[0];
mat1_col = mat1_shape_val[1];
}
if (input_shape_dtype_ == kNumberTypeInt32) {
auto input_shape_val = reinterpret_cast<int32_t *>(input_shape_addr);
auto input_shape_val = static_cast<int32_t *>(input_shape_addr);
input_row = static_cast<int64_t>(input_shape_val[0]);
input_col = static_cast<int64_t>(input_shape_val[1]);
} else {
auto input_shape_val = reinterpret_cast<int64_t *>(input_shape_addr);
auto input_shape_val = static_cast<int64_t *>(input_shape_addr);
input_row = input_shape_val[0];
input_col = input_shape_val[1];
}
@@ -173,22 +185,21 @@ void SspaddmmCPUKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, c
SparseAddSparse<int64_t, T>(input_indices_addr, input_values_addr_bak, input_values_num_, y_indices_addr,
y_values_addr, y_values_num_);
}
auto row = mat1_col;
auto col = input_col;
if (mat1_indices_dtype_ == kNumberTypeInt32) {
SparseMulDense<int32_t, T>(mat1_indices_addr, mat1_values_addr_bak, mat1_values_num_, mat2_addr, y_indices_addr,
y_values_addr, y_values_num_, row, col);
y_values_addr, y_values_num_, col);
} else {
SparseMulDense<int64_t, T>(mat1_indices_addr, mat1_values_addr_bak, mat1_values_num_, mat2_addr, y_indices_addr,
y_values_addr, y_values_num_, row, col);
y_values_addr, y_values_num_, col);
}
}
template <typename T, typename S>
void SspaddmmCPUKernelMod::CheckSparseIndicesLegal(void *indices_addr, void *shape_addr, size_t num,
std::string x_name) {
auto indices_val = reinterpret_cast<T *>(indices_addr);
auto shape_val = reinterpret_cast<S *>(shape_addr);
const std::string &x_name) const {
auto indices_val = static_cast<T *>(indices_addr);
auto shape_val = static_cast<S *>(shape_addr);
int shape_num = 2;
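// the sparse tensors handled here are 2-D, so both dense-shape entries are checked to be positive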
for (int i = 0; i < shape_num; i++) {
if (shape_val[i] <= 0) {
@@ -210,8 +221,8 @@ void SspaddmmCPUKernelMod::CheckSparseIndicesLegal(void *indices_addr, void *sha
}
template <typename T>
void SspaddmmCPUKernelMod::InitShape(void *input_shape, int64_t *y_shape) {
auto input_shape_val = reinterpret_cast<T *>(input_shape);
void SspaddmmCPUKernelMod::InitShape(void *input_shape, int64_t *y_shape) const {
auto input_shape_val = static_cast<T *>(input_shape);
size_t shape_num = 2;
for (size_t i = 0; i < shape_num; i++) {
y_shape[i] = static_cast<int64_t>(input_shape_val[i]);
@@ -231,7 +242,7 @@ void SspaddmmCPUKernelMod::ClearSparseValues(T *sparse_val, size_t data_num) {
// scalar * sparse matrix for beta * input alpha * mat1
template <typename T>
T *SspaddmmCPUKernelMod::ScalarSparseMul(T *sparse_val, void *scalar_val, size_t data_num, TypeId tid) {
T *SspaddmmCPUKernelMod::ScalarSparseMul(const T *sparse_val, void *scalar_val, size_t data_num, const TypeId &tid) {
T val;
if (!(data_num > 0)) {
MS_EXCEPTION(ValueError) << "For Sspaddmm, datanum value error. ";
@@ -239,50 +250,49 @@ T *SspaddmmCPUKernelMod::ScalarSparseMul(T *sparse_val, void *scalar_val, size_t
T *sparse_val_bak = new T[data_num];
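// convert the scalar (alpha or beta) to the value type T according to its runtime dtype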
switch (tid) {
case kNumberTypeUInt8:
val = static_cast<T>(reinterpret_cast<uint8_t *>(scalar_val)[0]);
val = static_cast<T>(static_cast<uint8_t *>(scalar_val)[0]);
break;
case kNumberTypeUInt16:
val = static_cast<T>(reinterpret_cast<uint16_t *>(scalar_val)[0]);
val = static_cast<T>(static_cast<uint16_t *>(scalar_val)[0]);
break;
case kNumberTypeUInt32:
val = static_cast<T>(reinterpret_cast<uint32_t *>(scalar_val)[0]);
val = static_cast<T>(static_cast<uint32_t *>(scalar_val)[0]);
break;
case kNumberTypeUInt64:
val = static_cast<T>(reinterpret_cast<uint64_t *>(scalar_val)[0]);
val = static_cast<T>(static_cast<uint64_t *>(scalar_val)[0]);
break;
case kNumberTypeInt8:
val = static_cast<T>(reinterpret_cast<int8_t *>(scalar_val)[0]);
val = static_cast<T>(static_cast<int8_t *>(scalar_val)[0]);
break;
case kNumberTypeInt16:
val = static_cast<T>(reinterpret_cast<int16_t *>(scalar_val)[0]);
val = static_cast<T>(static_cast<int16_t *>(scalar_val)[0]);
break;
case kNumberTypeInt32:
val = static_cast<T>(reinterpret_cast<int32_t *>(scalar_val)[0]);
val = static_cast<T>(static_cast<int32_t *>(scalar_val)[0]);
break;
case kNumberTypeInt64:
val = static_cast<T>(reinterpret_cast<int64_t *>(scalar_val)[0]);
val = static_cast<T>(static_cast<int64_t *>(scalar_val)[0]);
break;
case kNumberTypeFloat16:
val = static_cast<T>(reinterpret_cast<float16 *>(scalar_val)[0]);
val = static_cast<T>(static_cast<float16 *>(scalar_val)[0]);
break;
case kNumberTypeFloat32:
val = static_cast<T>(reinterpret_cast<float *>(scalar_val)[0]);
val = static_cast<T>(static_cast<float *>(scalar_val)[0]);
break;
case kNumberTypeFloat64:
val = static_cast<T>(reinterpret_cast<double *>(scalar_val)[0]);
val = static_cast<T>(static_cast<double *>(scalar_val)[0]);
break;
case kNumberTypeBool:
val = static_cast<T>(reinterpret_cast<bool *>(scalar_val)[0]);
val = static_cast<T>(static_cast<bool *>(scalar_val)[0]);
break;
case kNumberTypeComplex64:
val = static_cast<T>(reinterpret_cast<std::complex<float> *>(scalar_val)[0].real());
val = static_cast<T>(static_cast<std::complex<float> *>(scalar_val)[0].real());
break;
case kNumberTypeComplex128:
val = static_cast<T>(reinterpret_cast<std::complex<double> *>(scalar_val)[0].real());
val = static_cast<T>(static_cast<std::complex<double> *>(scalar_val)[0].real());
break;
default:
MS_EXCEPTION(TypeError) << "For Sspaddmm, dtype not support. ";
break;
}
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
@@ -296,10 +306,10 @@ T *SspaddmmCPUKernelMod::ScalarSparseMul(T *sparse_val, void *scalar_val, size_t
// sparse matrix add sparse matrix
// input + mat1 @ mat2
template <typename T, typename S>
void SspaddmmCPUKernelMod::SparseAddSparse(void *input_indices, S *input_values, size_t input_num, int64_t *y_indices,
S *y_values, size_t y_num) {
void SspaddmmCPUKernelMod::SparseAddSparse(void *input_indices, const S *input_values, size_t input_num,
int64_t *y_indices, S *y_values, size_t y_num) {
// to implement m1[row][col] = vals
auto input_ids = reinterpret_cast<T *>(input_indices);
auto input_ids = static_cast<T *>(input_indices);
this->cnt_ = input_num;
// get output vals and index addr
auto task = [&](size_t start, size_t end) {
@@ -315,11 +325,11 @@ void SspaddmmCPUKernelMod::SparseAddSparse(void *input_indices, S *input_values,
}
template <typename T, typename S>
void SspaddmmCPUKernelMod::SparseMulDense(void *mat1_indices, S *mat1_values, size_t mat1_vals_num, S *mat2_addr,
int64_t *y_indices, S *y_values, size_t y_vals_num, int64_t row,
int64_t mat2_col_) {
void SspaddmmCPUKernelMod::SparseMulDense(void *mat1_indices, const S *mat1_values, size_t mat1_vals_num,
const S *mat2_addr, int64_t *y_indices, S *y_values, size_t y_vals_num,
int64_t mat2_col) {
// the result of mat1 @ mat2 will write to output directly
auto mat1_ids = reinterpret_cast<T *>(mat1_indices);
auto mat1_ids = static_cast<T *>(mat1_indices);
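// idx_map_cnt maps each (mat1 row, output column) pair to its slot in the output value/index arrays;
// unrepeated records, per mat1 row, the mat1 column indices that contribute to that row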
std::unordered_map<T, std::unordered_map<int64_t, uint32_t>> idx_map_cnt;
std::unordered_map<T, std::vector<T>> unrepeated;
@@ -332,16 +342,16 @@ void SspaddmmCPUKernelMod::SparseMulDense(void *mat1_indices, S *mat1_values, si
T _col = mat1_ids[i + mat1_vals_num];
unrepeated[_row].push_back(_col);
co_map_idx[_row][_col].push_back(mat1_values[i]);
for (int64_t j = 0; j < mat2_col_; j++) {
for (int64_t j = 0; j < mat2_col; j++) {
if (idx_map_cnt[_row][j] == 0) {
idx_map_cnt[_row][j] = this->cnt_;
idx_map_cnt[_row][j] = SizeToUint(this->cnt_);
this->cnt_++;
}
}
}
std::vector<T> res;
for (auto it = unrepeated.begin(); it != unrepeated.end(); it++) {
for (auto it = unrepeated.begin(); it != unrepeated.end(); ++it) {
res.push_back(it->first);
}
@@ -353,10 +363,11 @@ void SspaddmmCPUKernelMod::SparseMulDense(void *mat1_indices, S *mat1_values, si
for (auto row_mat2 : unrepeated[row_mat1]) {
S val = co_map_idx[row_mat1][row_mat2].back();
co_map_idx[row_mat1][row_mat2].pop_back();
for (int64_t j = 0; j < mat2_col_; j++) {
for (int64_t j = 0; j < mat2_col; j++) {
// get val
T idx = idx_map_cnt[row_mat1][j];
*(y_values + idx) += val * mat2_addr[row_mat2 * mat2_col_ + j];
size_t idx = static_cast<size_t>(idx_map_cnt[row_mat1][j]);
int64_t row_mat2_long = static_cast<int64_t>(row_mat2);
*(y_values + idx) += val * mat2_addr[row_mat2_long * mat2_col + j];
y_indices[idx] = static_cast<int64_t>(row_mat1);
y_indices[idx + y_vals_num] = j;
}

@@ -30,28 +30,30 @@ class SspaddmmCPUKernelMod : public DeprecatedNativeCpuKernelMod {
SspaddmmCPUKernelMod() = default;
~SspaddmmCPUKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
private:
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
void CheckParam(const CNodePtr &kernel_node);
void CheckSparseIndices(const TypeId &indices_dtype, void *indices_addr, void *shape_addr, size_t num,
const std::string &x_name) const;
void CheckParam(const CNodePtr &kernel_node) const;
template <typename T, typename S>
void CheckSparseIndicesLegal(void *indices_addr, void *shape_addr, size_t num, std::string x_name);
void CheckSparseIndicesLegal(void *indices_addr, void *shape_addr, size_t num, const std::string &x_name) const;
template <typename T>
void InitShape(void *input_shape, int64_t *y_shape);
void InitShape(void *input_shape, int64_t *y_shape) const;
template <typename T>
void ClearSparseValues(T *sparse_val, size_t data_num);
template <typename T>
T *ScalarSparseMul(T *sparse_val, void *scalar_val, size_t data_num, TypeId tid);
T *ScalarSparseMul(const T *sparse_val, void *scalar_val, size_t data_num, const TypeId &tid);
template <typename T, typename S>
void SparseAddSparse(void *input_indices, S *inut_values, size_t input_num, int64_t *y_indices, S *y_values,
void SparseAddSparse(void *input_indices, const S *inut_values, size_t input_num, int64_t *y_indices, S *y_values,
size_t y_num);
template <typename T, typename S>
void SparseMulDense(void *mat1_indices, S *mat1_values, size_t mat1_vals_num, S *mat2_values_tensor,
int64_t *y_indices, S *y_values, size_t y_vals_num, int64_t row, int64_t mat2_col);
void SparseMulDense(void *mat1_indices, const S *mat1_values, size_t mat1_vals_num, const S *mat2_addr,
int64_t *y_indices, S *y_values, size_t y_vals_num, int64_t mat2_col);
TypeId output_values_dtype_{kTypeUnknown};
TypeId input_indices_dtype_{kTypeUnknown};

@@ -724,25 +724,35 @@ class EinsumHelper {
}
ShapeVector temp_vec(input_shapes[idx_input].begin() + idx_left,
input_shapes[idx_input].begin() + idx_shape_right + 1);
if (element_shape_map_.find(ELL_VAL) != element_shape_map_.end()) {
if (element_shape_map_[ELL_VAL] != temp_vec) {
if (temp_vec.size() == 0 || element_shape_map_[ELL_VAL].size() == 0) {
element_shape_map_[ELL_VAL] = temp_vec.size() == 0 ? element_shape_map_[ELL_VAL] : temp_vec;
} else {
MS_LOG(ERROR)
<< "For " << node_name_
<< ", the same ellipsis in equation can only represent the same dimension in inputs, but it does not.";
return false;
}
}
} else {
element_shape_map_[ELL_VAL] = temp_vec;
if (!AdjustElementMapShape(temp_vec)) {
return false;
}
}
}
return true;
}
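// Record or reconcile the shape covered by the ellipsis: an empty shape is compatible with anything,
// otherwise every input must map the ellipsis to the same dimensions.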
bool AdjustElementMapShape(const ShapeVector &temp_vec) {
const size_t ellipsis_val_num = 52;
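// 52 is the value the ELL_VAL constant in the old code used to key the ellipsis entry in element_shape_map_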
auto iter = element_shape_map_.find(ellipsis_val_num);
if (iter != element_shape_map_.end()) {
ShapeVector cur_vec = iter->second;
if (cur_vec != temp_vec) {
if (temp_vec.empty() || cur_vec.empty()) {
element_shape_map_[ellipsis_val_num] = temp_vec.empty() ? cur_vec : temp_vec;
} else {
MS_LOG(ERROR)
<< "For " << node_name_
<< ", the same ellipsis in equation can only represent the same dimension in inputs, but it does not.";
return false;
}
}
} else {
element_shape_map_[ellipsis_val_num] = temp_vec;
}
return true;
}
void CalAxisShape(const std::vector<size_t> &axis_val, const ShapeVector &shape_val, size_t *idx,
ShapeVector *re_shape, std::vector<size_t> *res_trans_axis) {
for (auto val : axis_val) {

@@ -90,7 +90,8 @@ std::vector<int64_t> GetOutputShape(const PrimitivePtr &primitive, const std::ve
}
} else {
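// out = floor((in + pad_begin + pad_end - kernel) / stride + 1) for each of the D, H and W dimensions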
out_d = DoubleToLong(std::floor((in_d + pad_list[0] + pad_list[1] - kernel_d) / stride_d + 1));
out_h = DoubleToLong(std::floor((in_h + pad_list[2] + pad_list[3] - kernel_h) / stride_h + 1));
out_h =
DoubleToLong(std::floor((in_h + pad_list[kInputIndex2] + pad_list[kInputIndex3] - kernel_h) / stride_h + 1));
out_w =
DoubleToLong(std::floor((in_w + pad_list[kInputIndex4] + pad_list[kInputIndex5] - kernel_w) / stride_w + 1));
}