Merge pull request !39597 from hujiahui8/clean_code
i-robot 2022-08-08 09:17:07 +00:00 committed by Gitee
commit d3eb1586b9
4 changed files with 116 additions and 92 deletions

View File

@@ -27,15 +27,26 @@ constexpr size_t kInputsNum = 9;
 constexpr size_t kOutputsNum = 3;
 constexpr char kKernelName[] = "Sspaddmm";
-#define CHECKSPARSEINDICES(dtype1, dtype2, indices, shapes, num, x_name)      \
-  if (dtype1 == kNumberTypeInt32) {                                           \
-    CheckSparseIndicesLegal<int32_t, int32_t>(indices, shapes, num, x_name);  \
-  } else {                                                                    \
-    CheckSparseIndicesLegal<int64_t, int64_t>(indices, shapes, num, x_name);  \
-  }
+constexpr size_t kOutputShapeIndex = 2;
+constexpr size_t kInputShapeIndex = 2;
+constexpr size_t kMat1IndiceIndex = 3;
+constexpr size_t kMat1ValueIndex = 4;
+constexpr size_t kMat1ShapeIndex = 5;
+constexpr size_t kMat2Index = 6;
+constexpr size_t kAlphaIndex = 7;
+constexpr size_t kBetaIndex = 8;
 }  // namespace
 
-void SspaddmmCPUKernelMod::CheckParam(const CNodePtr &kernel_node) {
+void SspaddmmCPUKernelMod::CheckSparseIndices(const TypeId &indices_dtype, void *indices_addr, void *shape_addr,
+                                              size_t num, const std::string &x_name) const {
+  if (indices_dtype == kNumberTypeInt32) {
+    CheckSparseIndicesLegal<int32_t, int32_t>(indices_addr, shape_addr, num, x_name);
+  } else {
+    CheckSparseIndicesLegal<int64_t, int64_t>(indices_addr, shape_addr, num, x_name);
+  }
+}
+
+void SspaddmmCPUKernelMod::CheckParam(const CNodePtr &kernel_node) const {
   size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
   CHECK_KERNEL_INPUTS_NUM(input_num, kInputsNum, kKernelName);
   size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
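Note: the hunk above replaces the multi-statement CHECKSPARSEINDICES macro with a const member function that dispatches on the runtime index dtype. A minimal standalone sketch of that dispatch pattern, using illustrative names rather than the MindSpore API:

// Standalone sketch of the macro-to-function refactor: a runtime dtype tag selects the
// template instantiation. Names (DType, CheckIndices, Dispatch) are illustrative only.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>

enum class DType { kInt32, kInt64 };

// Templated worker, analogous to CheckSparseIndicesLegal<T, S>().
template <typename T, typename S>
void CheckIndices(const void *indices, const void *shape, size_t num, const std::string &name) {
  const T *idx = static_cast<const T *>(indices);
  const S *shp = static_cast<const S *>(shape);
  for (size_t i = 0; i < num; ++i) {
    if (idx[i] < 0 || static_cast<S>(idx[i]) >= shp[0]) {
      std::cerr << name << ": index " << idx[i] << " out of range\n";
    }
  }
}

// Dispatcher, analogous to CheckSparseIndices(): one place that maps dtype -> template.
void Dispatch(DType dtype, const void *indices, const void *shape, size_t num, const std::string &name) {
  if (dtype == DType::kInt32) {
    CheckIndices<int32_t, int32_t>(indices, shape, num, name);
  } else {
    CheckIndices<int64_t, int64_t>(indices, shape, num, name);
  }
}

int main() {
  int32_t indices[] = {0, 2, 5};
  int32_t shape[] = {4, 4};
  Dispatch(DType::kInt32, indices, shape, 3, "x1");  // reports that index 5 is out of range
  return 0;
}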
@@ -59,14 +70,14 @@ void SspaddmmCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
   auto y_indices_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0);
   auto mat2_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex6);
-  input_values_num_ = input_indices_shape[1];
-  mat1_values_num_ = mat1_indices_shape[1];
-  y_values_num_ = y_indices_shape[1];
-  mat2_row_ = mat2_shape[0];
-  mat2_col_ = mat2_shape[1];
+  input_values_num_ = LongToSize(input_indices_shape[1]);
+  mat1_values_num_ = LongToSize(mat1_indices_shape[1]);
+  y_values_num_ = LongToSize(y_indices_shape[1]);
+  mat2_row_ = LongToSize(mat2_shape[0]);
+  mat2_col_ = LongToSize(mat2_shape[1]);
 }
 
-bool SspaddmmCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+bool SspaddmmCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                                   const std::vector<AddressPtr> &outputs) {
   switch (output_values_dtype_) {
     case kNumberTypeUInt8: {
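Note: the LongToSize wrappers above presumably guard the signed-to-unsigned conversion of inferred shape values. A hypothetical stand-in for such a helper (not the MindSpore implementation):

// Hypothetical LongToSize-style helper: convert a signed shape value to size_t,
// rejecting negatives instead of silently wrapping around.
#include <cstddef>
#include <cstdint>
#include <stdexcept>

inline size_t LongToSizeChecked(int64_t v) {
  if (v < 0) {
    throw std::out_of_range("shape dimension must be non-negative");
  }
  return static_cast<size_t>(v);
}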
@@ -107,38 +118,39 @@ bool SspaddmmCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs, const s
 template <typename T>
 void SspaddmmCPUKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
   auto input_indices_addr = inputs[0]->addr;
-  auto input_values_addr = reinterpret_cast<T *>(inputs[1]->addr);
-  auto input_shape_addr = inputs[2]->addr;
-  auto mat1_indices_addr = inputs[3]->addr;
-  auto mat1_values_addr = reinterpret_cast<T *>(inputs[4]->addr);
-  auto mat1_shape_addr = inputs[5]->addr;
-  auto mat2_addr = reinterpret_cast<T *>(inputs[6]->addr);
-  auto alpha_val_addr = inputs[7]->addr;
-  auto beta_val_addr = inputs[8]->addr;
-  auto y_indices_addr = reinterpret_cast<int64_t *>(outputs[0]->addr);
-  auto y_values_addr = reinterpret_cast<T *>(outputs[1]->addr);
-  auto y_shape_addr = reinterpret_cast<int64_t *>(outputs[2]->addr);
-  CHECKSPARSEINDICES(input_indices_dtype_, input_shape_dtype_, input_indices_addr, input_shape_addr, input_values_num_,
-                     "x1");
-  CHECKSPARSEINDICES(mat1_indices_dtype_, mat1_shape_dtype_, mat1_indices_addr, mat1_shape_addr, mat1_values_num_,
-                     "x2");
+  auto input_values_addr = GetDeviceAddress<T>(inputs, 1);
+  auto input_shape_addr = inputs[kInputShapeIndex]->addr;
+  auto mat1_indices_addr = inputs[kMat1IndiceIndex]->addr;
+  auto mat1_values_addr = GetDeviceAddress<T>(inputs, kMat1ValueIndex);
+  auto mat1_shape_addr = inputs[kMat1ShapeIndex]->addr;
+  auto mat2_addr = GetDeviceAddress<T>(inputs, kMat2Index);
+  auto alpha_val_addr = inputs[kAlphaIndex]->addr;
+  auto beta_val_addr = inputs[kBetaIndex]->addr;
+  auto y_indices_addr = GetDeviceAddress<int64_t>(outputs, 0);
+  auto y_values_addr = GetDeviceAddress<T>(outputs, 1);
+  auto y_shape_addr = GetDeviceAddress<int64_t>(outputs, kOutputShapeIndex);
+
+  const std::string x1_name = "x1";
+  const std::string x2_name = "x2";
+  CheckSparseIndices(input_indices_dtype_, input_indices_addr, input_shape_addr, input_values_num_, x1_name);
+  CheckSparseIndices(mat1_indices_dtype_, mat1_indices_addr, mat1_shape_addr, mat1_values_num_, x2_name);
 
   int64_t mat1_row, mat1_col, input_row, input_col;
   if (mat1_shape_dtype_ == kNumberTypeInt32) {
-    auto mat1_shape_val = reinterpret_cast<int32_t *>(mat1_shape_addr);
+    auto mat1_shape_val = static_cast<int32_t *>(mat1_shape_addr);
     mat1_row = static_cast<int64_t>(mat1_shape_val[0]);
     mat1_col = static_cast<int64_t>(mat1_shape_val[1]);
   } else {
-    auto mat1_shape_val = reinterpret_cast<int64_t *>(mat1_shape_addr);
+    auto mat1_shape_val = static_cast<int64_t *>(mat1_shape_addr);
     mat1_row = mat1_shape_val[0];
     mat1_col = mat1_shape_val[1];
   }
   if (input_shape_dtype_ == kNumberTypeInt32) {
-    auto input_shape_val = reinterpret_cast<int32_t *>(input_shape_addr);
+    auto input_shape_val = static_cast<int32_t *>(input_shape_addr);
     input_row = static_cast<int64_t>(input_shape_val[0]);
     input_col = static_cast<int64_t>(input_shape_val[1]);
   } else {
-    auto input_shape_val = reinterpret_cast<int64_t *>(input_shape_addr);
+    auto input_shape_val = static_cast<int64_t *>(input_shape_addr);
     input_row = input_shape_val[0];
     input_col = input_shape_val[1];
   }
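Note: the reinterpret_cast-to-static_cast changes above work because the kernel buffers arrive as void*, and C++ allows a direct static_cast from void* back to the original pointee type; reinterpret_cast is only needed between unrelated pointer types. A tiny self-contained illustration (not MindSpore code):

#include <cstdint>
#include <iostream>

int main() {
  int32_t shape[2] = {3, 4};
  void *erased = shape;                         // type-erased buffer, as in AddressPtr::addr
  auto typed = static_cast<int32_t *>(erased);  // well-defined: round-trip through void*
  int64_t rows = static_cast<int64_t>(typed[0]);  // widen to the common index type
  int64_t cols = static_cast<int64_t>(typed[1]);
  std::cout << rows << " x " << cols << "\n";   // prints "3 x 4"
  return 0;
}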
@@ -173,22 +185,21 @@ void SspaddmmCPUKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, c
     SparseAddSparse<int64_t, T>(input_indices_addr, input_values_addr_bak, input_values_num_, y_indices_addr,
                                 y_values_addr, y_values_num_);
   }
-  auto row = mat1_col;
   auto col = input_col;
   if (mat1_indices_dtype_ == kNumberTypeInt32) {
     SparseMulDense<int32_t, T>(mat1_indices_addr, mat1_values_addr_bak, mat1_values_num_, mat2_addr, y_indices_addr,
-                               y_values_addr, y_values_num_, row, col);
+                               y_values_addr, y_values_num_, col);
   } else {
     SparseMulDense<int64_t, T>(mat1_indices_addr, mat1_values_addr_bak, mat1_values_num_, mat2_addr, y_indices_addr,
-                               y_values_addr, y_values_num_, row, col);
+                               y_values_addr, y_values_num_, col);
   }
 }
 
 template <typename T, typename S>
 void SspaddmmCPUKernelMod::CheckSparseIndicesLegal(void *indices_addr, void *shape_addr, size_t num,
-                                                   std::string x_name) {
-  auto indices_val = reinterpret_cast<T *>(indices_addr);
-  auto shape_val = reinterpret_cast<S *>(shape_addr);
+                                                   const std::string &x_name) const {
+  auto indices_val = static_cast<T *>(indices_addr);
+  auto shape_val = static_cast<S *>(shape_addr);
   int shape_num = 2;
   for (int i = 0; i < shape_num; i++) {
     if (shape_val[i] <= 0) {
@@ -210,8 +221,8 @@ void SspaddmmCPUKernelMod::CheckSparseIndicesLegal(void *indices_addr, void *sha
 }
 
 template <typename T>
-void SspaddmmCPUKernelMod::InitShape(void *input_shape, int64_t *y_shape) {
-  auto input_shape_val = reinterpret_cast<T *>(input_shape);
+void SspaddmmCPUKernelMod::InitShape(void *input_shape, int64_t *y_shape) const {
+  auto input_shape_val = static_cast<T *>(input_shape);
   size_t shape_num = 2;
   for (size_t i = 0; i < shape_num; i++) {
     y_shape[i] = static_cast<int64_t>(input_shape_val[i]);
@@ -231,7 +242,7 @@ void SspaddmmCPUKernelMod::ClearSparseValues(T *sparse_val, size_t data_num) {
 // scalar * sparse matrix for beta * input alpha * mat1
 template <typename T>
-T *SspaddmmCPUKernelMod::ScalarSparseMul(T *sparse_val, void *scalar_val, size_t data_num, TypeId tid) {
+T *SspaddmmCPUKernelMod::ScalarSparseMul(const T *sparse_val, void *scalar_val, size_t data_num, const TypeId &tid) {
   T val;
   if (!(data_num > 0)) {
     MS_EXCEPTION(ValueError) << "For Sspaddmm, datanum value error. ";
@@ -239,50 +250,49 @@ T *SspaddmmCPUKernelMod::ScalarSparseMul(T *sparse_val, void *scalar_val, size_t
   T *sparse_val_bak = new T[data_num];
   switch (tid) {
     case kNumberTypeUInt8:
-      val = static_cast<T>(reinterpret_cast<uint8_t *>(scalar_val)[0]);
+      val = static_cast<T>(static_cast<uint8_t *>(scalar_val)[0]);
       break;
     case kNumberTypeUInt16:
-      val = static_cast<T>(reinterpret_cast<uint16_t *>(scalar_val)[0]);
+      val = static_cast<T>(static_cast<uint16_t *>(scalar_val)[0]);
       break;
     case kNumberTypeUInt32:
-      val = static_cast<T>(reinterpret_cast<uint32_t *>(scalar_val)[0]);
+      val = static_cast<T>(static_cast<uint32_t *>(scalar_val)[0]);
       break;
     case kNumberTypeUInt64:
-      val = static_cast<T>(reinterpret_cast<uint64_t *>(scalar_val)[0]);
+      val = static_cast<T>(static_cast<uint64_t *>(scalar_val)[0]);
       break;
     case kNumberTypeInt8:
-      val = static_cast<T>(reinterpret_cast<int8_t *>(scalar_val)[0]);
+      val = static_cast<T>(static_cast<int8_t *>(scalar_val)[0]);
      break;
     case kNumberTypeInt16:
-      val = static_cast<T>(reinterpret_cast<int16_t *>(scalar_val)[0]);
+      val = static_cast<T>(static_cast<int16_t *>(scalar_val)[0]);
       break;
     case kNumberTypeInt32:
-      val = static_cast<T>(reinterpret_cast<int32_t *>(scalar_val)[0]);
+      val = static_cast<T>(static_cast<int32_t *>(scalar_val)[0]);
       break;
     case kNumberTypeInt64:
-      val = static_cast<T>(reinterpret_cast<int64_t *>(scalar_val)[0]);
+      val = static_cast<T>(static_cast<int64_t *>(scalar_val)[0]);
       break;
     case kNumberTypeFloat16:
-      val = static_cast<T>(reinterpret_cast<float16 *>(scalar_val)[0]);
+      val = static_cast<T>(static_cast<float16 *>(scalar_val)[0]);
       break;
     case kNumberTypeFloat32:
-      val = static_cast<T>(reinterpret_cast<float *>(scalar_val)[0]);
+      val = static_cast<T>(static_cast<float *>(scalar_val)[0]);
       break;
     case kNumberTypeFloat64:
-      val = static_cast<T>(reinterpret_cast<double *>(scalar_val)[0]);
+      val = static_cast<T>(static_cast<double *>(scalar_val)[0]);
       break;
     case kNumberTypeBool:
-      val = static_cast<T>(reinterpret_cast<bool *>(scalar_val)[0]);
+      val = static_cast<T>(static_cast<bool *>(scalar_val)[0]);
       break;
     case kNumberTypeComplex64:
-      val = static_cast<T>(reinterpret_cast<std::complex<float> *>(scalar_val)[0].real());
+      val = static_cast<T>(static_cast<std::complex<float> *>(scalar_val)[0].real());
       break;
     case kNumberTypeComplex128:
-      val = static_cast<T>(reinterpret_cast<std::complex<double> *>(scalar_val)[0].real());
+      val = static_cast<T>(static_cast<std::complex<double> *>(scalar_val)[0].real());
       break;
     default:
       MS_EXCEPTION(TypeError) << "For Sspaddmm, dtype not support. ";
-      break;
   }
   auto task = [&](size_t start, size_t end) {
     for (size_t i = start; i < end; i++) {
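Note: ScalarSparseMul reads a scalar whose concrete type is only known at run time, then scales a copy of the sparse value buffer by it. A short standalone sketch of that idea, reduced to two scalar types and returning a std::vector instead of a raw new[] buffer (all names illustrative):

#include <cstddef>
#include <cstdint>
#include <vector>

enum class ScalarTag { kFloat32, kInt64 };

template <typename T>
std::vector<T> ScaleValues(const T *vals, size_t n, const void *scalar, ScalarTag tag) {
  T s;
  switch (tag) {
    case ScalarTag::kFloat32:
      s = static_cast<T>(*static_cast<const float *>(scalar));
      break;
    case ScalarTag::kInt64:
      s = static_cast<T>(*static_cast<const int64_t *>(scalar));
      break;
  }
  std::vector<T> out(n);
  for (size_t i = 0; i < n; ++i) {
    out[i] = vals[i] * s;  // plays the role of beta * input or alpha * mat1 in Sspaddmm
  }
  return out;
}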
@@ -296,10 +306,10 @@ T *SspaddmmCPUKernelMod::ScalarSparseMul(T *sparse_val, void *scalar_val, size_t
 // sparse matrix add sparse matrix
 // input + mat1 @ mat2
 template <typename T, typename S>
-void SspaddmmCPUKernelMod::SparseAddSparse(void *input_indices, S *input_values, size_t input_num, int64_t *y_indices,
-                                           S *y_values, size_t y_num) {
+void SspaddmmCPUKernelMod::SparseAddSparse(void *input_indices, const S *input_values, size_t input_num,
+                                           int64_t *y_indices, S *y_values, size_t y_num) {
   // to implement m1[row][col] = vals
-  auto input_ids = reinterpret_cast<T *>(input_indices);
+  auto input_ids = static_cast<T *>(input_indices);
   this->cnt_ = input_num;
   // get output vals and index addr
   auto task = [&](size_t start, size_t end) {
@@ -315,11 +325,11 @@ void SspaddmmCPUKernelMod::SparseAddSparse(void *input_indices, S *input_values,
 }
 
 template <typename T, typename S>
-void SspaddmmCPUKernelMod::SparseMulDense(void *mat1_indices, S *mat1_values, size_t mat1_vals_num, S *mat2_addr,
-                                          int64_t *y_indices, S *y_values, size_t y_vals_num, int64_t row,
-                                          int64_t mat2_col_) {
+void SspaddmmCPUKernelMod::SparseMulDense(void *mat1_indices, const S *mat1_values, size_t mat1_vals_num,
+                                          const S *mat2_addr, int64_t *y_indices, S *y_values, size_t y_vals_num,
+                                          int64_t mat2_col) {
   // the result of mat1 @ mat2 will write to output directly
-  auto mat1_ids = reinterpret_cast<T *>(mat1_indices);
+  auto mat1_ids = static_cast<T *>(mat1_indices);
   std::unordered_map<T, std::unordered_map<int64_t, uint32_t>> idx_map_cnt;
   std::unordered_map<T, std::vector<T>> unrepeated;
@@ -332,16 +342,16 @@ void SspaddmmCPUKernelMod::SparseMulDense(void *mat1_indices, S *mat1_values, si
     T _col = mat1_ids[i + mat1_vals_num];
     unrepeated[_row].push_back(_col);
     co_map_idx[_row][_col].push_back(mat1_values[i]);
-    for (int64_t j = 0; j < mat2_col_; j++) {
+    for (int64_t j = 0; j < mat2_col; j++) {
       if (idx_map_cnt[_row][j] == 0) {
-        idx_map_cnt[_row][j] = this->cnt_;
+        idx_map_cnt[_row][j] = SizeToUint(this->cnt_);
         this->cnt_++;
       }
     }
   }
 
   std::vector<T> res;
-  for (auto it = unrepeated.begin(); it != unrepeated.end(); it++) {
+  for (auto it = unrepeated.begin(); it != unrepeated.end(); ++it) {
     res.push_back(it->first);
   }
@@ -353,10 +363,11 @@ void SspaddmmCPUKernelMod::SparseMulDense(void *mat1_indices, S *mat1_values, si
     for (auto row_mat2 : unrepeated[row_mat1]) {
       S val = co_map_idx[row_mat1][row_mat2].back();
       co_map_idx[row_mat1][row_mat2].pop_back();
-      for (int64_t j = 0; j < mat2_col_; j++) {
+      for (int64_t j = 0; j < mat2_col; j++) {
         // get val
-        T idx = idx_map_cnt[row_mat1][j];
-        *(y_values + idx) += val * mat2_addr[row_mat2 * mat2_col_ + j];
+        size_t idx = static_cast<size_t>(idx_map_cnt[row_mat1][j]);
+        int64_t row_mat2_long = static_cast<int64_t>(row_mat2);
+        *(y_values + idx) += val * mat2_addr[row_mat2_long * mat2_col + j];
         y_indices[idx] = static_cast<int64_t>(row_mat1);
         y_indices[idx + y_vals_num] = j;
       }
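Note: the loop above is the core of SparseMulDense: every nonzero (r, c, v) of COO-format mat1 contributes v * mat2[c][j] to output row r, which is mat1 @ mat2 accumulated fragment by fragment. A standalone sketch with made-up values (not the kernel's buffers):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // mat1 (2 x 3) in COO form: rows, cols, values of the nonzeros.
  std::vector<int64_t> rows = {0, 1, 1};
  std::vector<int64_t> cols = {2, 0, 2};
  std::vector<float> vals = {2.0f, 1.0f, 3.0f};
  // mat2 is 3 x 2, stored row-major.
  std::vector<float> mat2 = {1, 2, 3, 4, 5, 6};
  const int64_t mat2_col = 2;

  std::vector<float> out(2 * mat2_col, 0.0f);  // dense 2 x 2 accumulator
  for (size_t i = 0; i < vals.size(); ++i) {
    for (int64_t j = 0; j < mat2_col; ++j) {
      out[rows[i] * mat2_col + j] += vals[i] * mat2[cols[i] * mat2_col + j];
    }
  }
  for (float v : out) std::cout << v << ' ';  // prints: 10 12 16 20
  std::cout << '\n';
  return 0;
}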

View File

@@ -30,28 +30,30 @@ class SspaddmmCPUKernelMod : public DeprecatedNativeCpuKernelMod {
   SspaddmmCPUKernelMod() = default;
   ~SspaddmmCPUKernelMod() override = default;
   void InitKernel(const CNodePtr &kernel_node) override;
 
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
               const std::vector<AddressPtr> &outputs) override;
 
  private:
   template <typename T>
   void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
-  void CheckParam(const CNodePtr &kernel_node);
+  void CheckSparseIndices(const TypeId &indices_dtype, void *indices_addr, void *shape_addr, size_t num,
+                          const std::string &x_name) const;
+  void CheckParam(const CNodePtr &kernel_node) const;
   template <typename T, typename S>
-  void CheckSparseIndicesLegal(void *indices_addr, void *shape_addr, size_t num, std::string x_name);
+  void CheckSparseIndicesLegal(void *indices_addr, void *shape_addr, size_t num, const std::string &x_name) const;
   template <typename T>
-  void InitShape(void *input_shape, int64_t *y_shape);
+  void InitShape(void *input_shape, int64_t *y_shape) const;
   template <typename T>
   void ClearSparseValues(T *sparse_val, size_t data_num);
   template <typename T>
-  T *ScalarSparseMul(T *sparse_val, void *scalar_val, size_t data_num, TypeId tid);
+  T *ScalarSparseMul(const T *sparse_val, void *scalar_val, size_t data_num, const TypeId &tid);
   template <typename T, typename S>
-  void SparseAddSparse(void *input_indices, S *inut_values, size_t input_num, int64_t *y_indices, S *y_values,
+  void SparseAddSparse(void *input_indices, const S *inut_values, size_t input_num, int64_t *y_indices, S *y_values,
                        size_t y_num);
   template <typename T, typename S>
-  void SparseMulDense(void *mat1_indices, S *mat1_values, size_t mat1_vals_num, S *mat2_values_tensor,
-                      int64_t *y_indices, S *y_values, size_t y_vals_num, int64_t row, int64_t mat2_col);
+  void SparseMulDense(void *mat1_indices, const S *mat1_values, size_t mat1_vals_num, const S *mat2_addr,
+                      int64_t *y_indices, S *y_values, size_t y_vals_num, int64_t mat2_col);
 
   TypeId output_values_dtype_{kTypeUnknown};
   TypeId input_indices_dtype_{kTypeUnknown};
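Note: leaving the workspace parameter of Launch unnamed is the usual idiom for an argument that must stay in the signature but is intentionally unused; it also avoids unused-parameter warnings. A tiny illustration (not MindSpore code):

#include <vector>

int Sum(const std::vector<int> &values, const std::vector<int> & /* workspace, unused */) {
  int total = 0;
  for (int v : values) total += v;
  return total;
}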

View File

@@ -724,25 +724,35 @@ class EinsumHelper {
       }
       ShapeVector temp_vec(input_shapes[idx_input].begin() + idx_left,
                            input_shapes[idx_input].begin() + idx_shape_right + 1);
-      if (element_shape_map_.find(ELL_VAL) != element_shape_map_.end()) {
-        if (element_shape_map_[ELL_VAL] != temp_vec) {
-          if (temp_vec.size() == 0 || element_shape_map_[ELL_VAL].size() == 0) {
-            element_shape_map_[ELL_VAL] = temp_vec.size() == 0 ? element_shape_map_[ELL_VAL] : temp_vec;
-          } else {
-            MS_LOG(ERROR)
-              << "For " << node_name_
-              << ", the same ellipsis in equation can only represent the same dimension in inputs, but it does not.";
-            return false;
-          }
-        }
-      } else {
-        element_shape_map_[ELL_VAL] = temp_vec;
+      if (!AdjustElementMapShape(temp_vec)) {
+        return false;
       }
     }
   }
   return true;
 }
+
+bool AdjustElementMapShape(const ShapeVector &temp_vec) {
+  const size_t ellipsis_val_num = 52;
+  auto iter = element_shape_map_.find(ellipsis_val_num);
+  if (iter != element_shape_map_.end()) {
+    ShapeVector cur_vec = iter->second;
+    if (cur_vec != temp_vec) {
+      if (temp_vec.empty() || cur_vec.empty()) {
+        element_shape_map_[ellipsis_val_num] = temp_vec.empty() ? cur_vec : temp_vec;
+      } else {
+        MS_LOG(ERROR)
+          << "For " << node_name_
+          << ", the same ellipsis in equation can only represent the same dimension in inputs, but it does not.";
+        return false;
+      }
+    }
+  } else {
+    element_shape_map_[ellipsis_val_num] = temp_vec;
+  }
+  return true;
+}
+
 void CalAxisShape(const std::vector<size_t> &axis_val, const ShapeVector &shape_val, size_t *idx,
                   ShapeVector *re_shape, std::vector<size_t> *res_trans_axis) {
   for (auto val : axis_val) {
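Note: the new AdjustElementMapShape factors out a record-or-verify check on the ellipsis shape: the first occurrence is recorded, later occurrences must match, and an empty shape never overrides a non-empty one. A hypothetical standalone version of the same pattern (illustrative names, not the EinsumHelper API):

#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

using Shape = std::vector<int64_t>;

bool RecordOrVerify(std::map<int, Shape> *shape_map, int label, const Shape &shape) {
  auto iter = shape_map->find(label);
  if (iter == shape_map->end()) {
    (*shape_map)[label] = shape;  // first occurrence: just record it
    return true;
  }
  if (iter->second == shape) {
    return true;  // consistent with what was recorded
  }
  if (shape.empty() || iter->second.empty()) {
    (*shape_map)[label] = shape.empty() ? iter->second : shape;  // keep the non-empty one
    return true;
  }
  return false;  // genuinely conflicting shapes
}

int main() {
  std::map<int, Shape> m;
  RecordOrVerify(&m, 52, {2, 3});
  std::cout << std::boolalpha << RecordOrVerify(&m, 52, {2, 3}) << '\n';  // true
  std::cout << RecordOrVerify(&m, 52, {4}) << '\n';                       // false
  return 0;
}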

View File

@@ -90,7 +90,8 @@ std::vector<int64_t> GetOutputShape(const PrimitivePtr &primitive, const std::ve
     }
   } else {
     out_d = DoubleToLong(std::floor((in_d + pad_list[0] + pad_list[1] - kernel_d) / stride_d + 1));
-    out_h = DoubleToLong(std::floor((in_h + pad_list[2] + pad_list[3] - kernel_h) / stride_h + 1));
+    out_h =
+      DoubleToLong(std::floor((in_h + pad_list[kInputIndex2] + pad_list[kInputIndex3] - kernel_h) / stride_h + 1));
     out_w =
       DoubleToLong(std::floor((in_w + pad_list[kInputIndex4] + pad_list[kInputIndex5] - kernel_w) / stride_w + 1));
   }
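Note: a quick numeric check of the output-size formula used above, out = floor((in + pad_front + pad_back - kernel) / stride + 1), with made-up values:

#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  const double in_h = 7, pad_top = 1, pad_bottom = 1, kernel_h = 3, stride_h = 2;
  const int64_t out_h =
    static_cast<int64_t>(std::floor((in_h + pad_top + pad_bottom - kernel_h) / stride_h + 1));
  std::cout << out_h << '\n';  // (7 + 1 + 1 - 3) / 2 + 1 = 4
  return 0;
}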