clean code
This commit is contained in:
parent
2b35007b0c
commit
886610fd80
|
@ -27,15 +27,26 @@ constexpr size_t kInputsNum = 9;
|
|||
constexpr size_t kOutputsNum = 3;
|
||||
constexpr char kKernelName[] = "Sspaddmm";
|
||||
|
||||
#define CHECKSPARSEINDICES(dtype1, dtype2, indices, shapes, num, x_name) \
|
||||
if (dtype1 == kNumberTypeInt32) { \
|
||||
CheckSparseIndicesLegal<int32_t, int32_t>(indices, shapes, num, x_name); \
|
||||
} else { \
|
||||
CheckSparseIndicesLegal<int64_t, int64_t>(indices, shapes, num, x_name); \
|
||||
}
|
||||
constexpr size_t kOutputShapeIndex = 2;
|
||||
constexpr size_t kInputShapeIndex = 2;
|
||||
constexpr size_t kMat1IndiceIndex = 3;
|
||||
constexpr size_t kMat1ValueIndex = 4;
|
||||
constexpr size_t kMat1ShapeIndex = 5;
|
||||
constexpr size_t kMat2Index = 6;
|
||||
constexpr size_t kAlphaIndex = 7;
|
||||
constexpr size_t kBetaIndex = 8;
|
||||
} // namespace
|
||||
|
||||
void SspaddmmCPUKernelMod::CheckParam(const CNodePtr &kernel_node) {
|
||||
void SspaddmmCPUKernelMod::CheckSparseIndices(const TypeId &indices_dtype, void *indices_addr, void *shape_addr,
|
||||
size_t num, const std::string &x_name) const {
|
||||
if (indices_dtype == kNumberTypeInt32) {
|
||||
CheckSparseIndicesLegal<int32_t, int32_t>(indices_addr, shape_addr, num, x_name);
|
||||
} else {
|
||||
CheckSparseIndicesLegal<int64_t, int64_t>(indices_addr, shape_addr, num, x_name);
|
||||
}
|
||||
}
|
||||
|
||||
void SspaddmmCPUKernelMod::CheckParam(const CNodePtr &kernel_node) const {
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
CHECK_KERNEL_INPUTS_NUM(input_num, kInputsNum, kKernelName);
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
|
@ -59,14 +70,14 @@ void SspaddmmCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
auto y_indices_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, kIndex0);
|
||||
auto mat2_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kIndex6);
|
||||
|
||||
input_values_num_ = input_indices_shape[1];
|
||||
mat1_values_num_ = mat1_indices_shape[1];
|
||||
y_values_num_ = y_indices_shape[1];
|
||||
mat2_row_ = mat2_shape[0];
|
||||
mat2_col_ = mat2_shape[1];
|
||||
input_values_num_ = LongToSize(input_indices_shape[1]);
|
||||
mat1_values_num_ = LongToSize(mat1_indices_shape[1]);
|
||||
y_values_num_ = LongToSize(y_indices_shape[1]);
|
||||
mat2_row_ = LongToSize(mat2_shape[0]);
|
||||
mat2_col_ = LongToSize(mat2_shape[1]);
|
||||
}
|
||||
|
||||
bool SspaddmmCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
bool SspaddmmCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
switch (output_values_dtype_) {
|
||||
case kNumberTypeUInt8: {
|
||||
|
@ -107,38 +118,39 @@ bool SspaddmmCPUKernelMod::Launch(const std::vector<AddressPtr> &inputs, const s
|
|||
template <typename T>
|
||||
void SspaddmmCPUKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
|
||||
auto input_indices_addr = inputs[0]->addr;
|
||||
auto input_values_addr = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto input_shape_addr = inputs[2]->addr;
|
||||
auto mat1_indices_addr = inputs[3]->addr;
|
||||
auto mat1_values_addr = reinterpret_cast<T *>(inputs[4]->addr);
|
||||
auto mat1_shape_addr = inputs[5]->addr;
|
||||
auto mat2_addr = reinterpret_cast<T *>(inputs[6]->addr);
|
||||
auto alpha_val_addr = inputs[7]->addr;
|
||||
auto beta_val_addr = inputs[8]->addr;
|
||||
auto y_indices_addr = reinterpret_cast<int64_t *>(outputs[0]->addr);
|
||||
auto y_values_addr = reinterpret_cast<T *>(outputs[1]->addr);
|
||||
auto y_shape_addr = reinterpret_cast<int64_t *>(outputs[2]->addr);
|
||||
CHECKSPARSEINDICES(input_indices_dtype_, input_shape_dtype_, input_indices_addr, input_shape_addr, input_values_num_,
|
||||
"x1");
|
||||
CHECKSPARSEINDICES(mat1_indices_dtype_, mat1_shape_dtype_, mat1_indices_addr, mat1_shape_addr, mat1_values_num_,
|
||||
"x2");
|
||||
auto input_values_addr = GetDeviceAddress<T>(inputs, 1);
|
||||
auto input_shape_addr = inputs[kInputShapeIndex]->addr;
|
||||
auto mat1_indices_addr = inputs[kMat1IndiceIndex]->addr;
|
||||
auto mat1_values_addr = GetDeviceAddress<T>(inputs, kMat1ValueIndex);
|
||||
auto mat1_shape_addr = inputs[kMat1ShapeIndex]->addr;
|
||||
auto mat2_addr = GetDeviceAddress<T>(inputs, kMat2Index);
|
||||
auto alpha_val_addr = inputs[kAlphaIndex]->addr;
|
||||
auto beta_val_addr = inputs[kBetaIndex]->addr;
|
||||
auto y_indices_addr = GetDeviceAddress<int64_t>(outputs, 0);
|
||||
auto y_values_addr = GetDeviceAddress<T>(outputs, 1);
|
||||
auto y_shape_addr = GetDeviceAddress<int64_t>(outputs, kOutputShapeIndex);
|
||||
|
||||
const std::string x1_name = "x1";
|
||||
const std::string x2_name = "x2";
|
||||
CheckSparseIndices(input_indices_dtype_, input_indices_addr, input_shape_addr, input_values_num_, x1_name);
|
||||
CheckSparseIndices(mat1_indices_dtype_, mat1_indices_addr, mat1_shape_addr, mat1_values_num_, x2_name);
|
||||
|
||||
int64_t mat1_row, mat1_col, input_row, input_col;
|
||||
if (mat1_shape_dtype_ == kNumberTypeInt32) {
|
||||
auto mat1_shape_val = reinterpret_cast<int32_t *>(mat1_shape_addr);
|
||||
auto mat1_shape_val = static_cast<int32_t *>(mat1_shape_addr);
|
||||
mat1_row = static_cast<int64_t>(mat1_shape_val[0]);
|
||||
mat1_col = static_cast<int64_t>(mat1_shape_val[1]);
|
||||
} else {
|
||||
auto mat1_shape_val = reinterpret_cast<int64_t *>(mat1_shape_addr);
|
||||
auto mat1_shape_val = static_cast<int64_t *>(mat1_shape_addr);
|
||||
mat1_row = mat1_shape_val[0];
|
||||
mat1_col = mat1_shape_val[1];
|
||||
}
|
||||
if (input_shape_dtype_ == kNumberTypeInt32) {
|
||||
auto input_shape_val = reinterpret_cast<int32_t *>(input_shape_addr);
|
||||
auto input_shape_val = static_cast<int32_t *>(input_shape_addr);
|
||||
input_row = static_cast<int64_t>(input_shape_val[0]);
|
||||
input_col = static_cast<int64_t>(input_shape_val[1]);
|
||||
} else {
|
||||
auto input_shape_val = reinterpret_cast<int64_t *>(input_shape_addr);
|
||||
auto input_shape_val = static_cast<int64_t *>(input_shape_addr);
|
||||
input_row = input_shape_val[0];
|
||||
input_col = input_shape_val[1];
|
||||
}
|
||||
|
@ -173,22 +185,21 @@ void SspaddmmCPUKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, c
|
|||
SparseAddSparse<int64_t, T>(input_indices_addr, input_values_addr_bak, input_values_num_, y_indices_addr,
|
||||
y_values_addr, y_values_num_);
|
||||
}
|
||||
auto row = mat1_col;
|
||||
auto col = input_col;
|
||||
if (mat1_indices_dtype_ == kNumberTypeInt32) {
|
||||
SparseMulDense<int32_t, T>(mat1_indices_addr, mat1_values_addr_bak, mat1_values_num_, mat2_addr, y_indices_addr,
|
||||
y_values_addr, y_values_num_, row, col);
|
||||
y_values_addr, y_values_num_, col);
|
||||
} else {
|
||||
SparseMulDense<int64_t, T>(mat1_indices_addr, mat1_values_addr_bak, mat1_values_num_, mat2_addr, y_indices_addr,
|
||||
y_values_addr, y_values_num_, row, col);
|
||||
y_values_addr, y_values_num_, col);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void SspaddmmCPUKernelMod::CheckSparseIndicesLegal(void *indices_addr, void *shape_addr, size_t num,
|
||||
std::string x_name) {
|
||||
auto indices_val = reinterpret_cast<T *>(indices_addr);
|
||||
auto shape_val = reinterpret_cast<S *>(shape_addr);
|
||||
const std::string &x_name) const {
|
||||
auto indices_val = static_cast<T *>(indices_addr);
|
||||
auto shape_val = static_cast<S *>(shape_addr);
|
||||
int shape_num = 2;
|
||||
for (int i = 0; i < shape_num; i++) {
|
||||
if (shape_val[i] <= 0) {
|
||||
|
@ -210,8 +221,8 @@ void SspaddmmCPUKernelMod::CheckSparseIndicesLegal(void *indices_addr, void *sha
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void SspaddmmCPUKernelMod::InitShape(void *input_shape, int64_t *y_shape) {
|
||||
auto input_shape_val = reinterpret_cast<T *>(input_shape);
|
||||
void SspaddmmCPUKernelMod::InitShape(void *input_shape, int64_t *y_shape) const {
|
||||
auto input_shape_val = static_cast<T *>(input_shape);
|
||||
size_t shape_num = 2;
|
||||
for (size_t i = 0; i < shape_num; i++) {
|
||||
y_shape[i] = static_cast<int64_t>(input_shape_val[i]);
|
||||
|
@ -231,7 +242,7 @@ void SspaddmmCPUKernelMod::ClearSparseValues(T *sparse_val, size_t data_num) {
|
|||
|
||||
// scalar * sparse matrix for beta * input alpha * mat1
|
||||
template <typename T>
|
||||
T *SspaddmmCPUKernelMod::ScalarSparseMul(T *sparse_val, void *scalar_val, size_t data_num, TypeId tid) {
|
||||
T *SspaddmmCPUKernelMod::ScalarSparseMul(const T *sparse_val, void *scalar_val, size_t data_num, const TypeId &tid) {
|
||||
T val;
|
||||
if (!(data_num > 0)) {
|
||||
MS_EXCEPTION(ValueError) << "For Sspaddmm, datanum value error. ";
|
||||
|
@ -239,50 +250,49 @@ T *SspaddmmCPUKernelMod::ScalarSparseMul(T *sparse_val, void *scalar_val, size_t
|
|||
T *sparse_val_bak = new T[data_num];
|
||||
switch (tid) {
|
||||
case kNumberTypeUInt8:
|
||||
val = static_cast<T>(reinterpret_cast<uint8_t *>(scalar_val)[0]);
|
||||
val = static_cast<T>(static_cast<uint8_t *>(scalar_val)[0]);
|
||||
break;
|
||||
case kNumberTypeUInt16:
|
||||
val = static_cast<T>(reinterpret_cast<uint16_t *>(scalar_val)[0]);
|
||||
val = static_cast<T>(static_cast<uint16_t *>(scalar_val)[0]);
|
||||
break;
|
||||
case kNumberTypeUInt32:
|
||||
val = static_cast<T>(reinterpret_cast<uint32_t *>(scalar_val)[0]);
|
||||
val = static_cast<T>(static_cast<uint32_t *>(scalar_val)[0]);
|
||||
break;
|
||||
case kNumberTypeUInt64:
|
||||
val = static_cast<T>(reinterpret_cast<uint64_t *>(scalar_val)[0]);
|
||||
val = static_cast<T>(static_cast<uint64_t *>(scalar_val)[0]);
|
||||
break;
|
||||
case kNumberTypeInt8:
|
||||
val = static_cast<T>(reinterpret_cast<int8_t *>(scalar_val)[0]);
|
||||
val = static_cast<T>(static_cast<int8_t *>(scalar_val)[0]);
|
||||
break;
|
||||
case kNumberTypeInt16:
|
||||
val = static_cast<T>(reinterpret_cast<int16_t *>(scalar_val)[0]);
|
||||
val = static_cast<T>(static_cast<int16_t *>(scalar_val)[0]);
|
||||
break;
|
||||
case kNumberTypeInt32:
|
||||
val = static_cast<T>(reinterpret_cast<int32_t *>(scalar_val)[0]);
|
||||
val = static_cast<T>(static_cast<int32_t *>(scalar_val)[0]);
|
||||
break;
|
||||
case kNumberTypeInt64:
|
||||
val = static_cast<T>(reinterpret_cast<int64_t *>(scalar_val)[0]);
|
||||
val = static_cast<T>(static_cast<int64_t *>(scalar_val)[0]);
|
||||
break;
|
||||
case kNumberTypeFloat16:
|
||||
val = static_cast<T>(reinterpret_cast<float16 *>(scalar_val)[0]);
|
||||
val = static_cast<T>(static_cast<float16 *>(scalar_val)[0]);
|
||||
break;
|
||||
case kNumberTypeFloat32:
|
||||
val = static_cast<T>(reinterpret_cast<float *>(scalar_val)[0]);
|
||||
val = static_cast<T>(static_cast<float *>(scalar_val)[0]);
|
||||
break;
|
||||
case kNumberTypeFloat64:
|
||||
val = static_cast<T>(reinterpret_cast<double *>(scalar_val)[0]);
|
||||
val = static_cast<T>(static_cast<double *>(scalar_val)[0]);
|
||||
break;
|
||||
case kNumberTypeBool:
|
||||
val = static_cast<T>(reinterpret_cast<bool *>(scalar_val)[0]);
|
||||
val = static_cast<T>(static_cast<bool *>(scalar_val)[0]);
|
||||
break;
|
||||
case kNumberTypeComplex64:
|
||||
val = static_cast<T>(reinterpret_cast<std::complex<float> *>(scalar_val)[0].real());
|
||||
val = static_cast<T>(static_cast<std::complex<float> *>(scalar_val)[0].real());
|
||||
break;
|
||||
case kNumberTypeComplex128:
|
||||
val = static_cast<T>(reinterpret_cast<std::complex<double> *>(scalar_val)[0].real());
|
||||
val = static_cast<T>(static_cast<std::complex<double> *>(scalar_val)[0].real());
|
||||
break;
|
||||
default:
|
||||
MS_EXCEPTION(TypeError) << "For Sspaddmm, dtype not support. ";
|
||||
break;
|
||||
}
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
|
@ -296,10 +306,10 @@ T *SspaddmmCPUKernelMod::ScalarSparseMul(T *sparse_val, void *scalar_val, size_t
|
|||
// sparse matrix add sparse matrix
|
||||
// input + mat1 @ mat2
|
||||
template <typename T, typename S>
|
||||
void SspaddmmCPUKernelMod::SparseAddSparse(void *input_indices, S *input_values, size_t input_num, int64_t *y_indices,
|
||||
S *y_values, size_t y_num) {
|
||||
void SspaddmmCPUKernelMod::SparseAddSparse(void *input_indices, const S *input_values, size_t input_num,
|
||||
int64_t *y_indices, S *y_values, size_t y_num) {
|
||||
// to implement m1[row][col] = vals
|
||||
auto input_ids = reinterpret_cast<T *>(input_indices);
|
||||
auto input_ids = static_cast<T *>(input_indices);
|
||||
this->cnt_ = input_num;
|
||||
// get output vals and index addr
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
|
@ -315,11 +325,11 @@ void SspaddmmCPUKernelMod::SparseAddSparse(void *input_indices, S *input_values,
|
|||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void SspaddmmCPUKernelMod::SparseMulDense(void *mat1_indices, S *mat1_values, size_t mat1_vals_num, S *mat2_addr,
|
||||
int64_t *y_indices, S *y_values, size_t y_vals_num, int64_t row,
|
||||
int64_t mat2_col_) {
|
||||
void SspaddmmCPUKernelMod::SparseMulDense(void *mat1_indices, const S *mat1_values, size_t mat1_vals_num,
|
||||
const S *mat2_addr, int64_t *y_indices, S *y_values, size_t y_vals_num,
|
||||
int64_t mat2_col) {
|
||||
// the result of mat1 @ mat2 will write to output directly
|
||||
auto mat1_ids = reinterpret_cast<T *>(mat1_indices);
|
||||
auto mat1_ids = static_cast<T *>(mat1_indices);
|
||||
|
||||
std::unordered_map<T, std::unordered_map<int64_t, uint32_t>> idx_map_cnt;
|
||||
std::unordered_map<T, std::vector<T>> unrepeated;
|
||||
|
@ -332,16 +342,16 @@ void SspaddmmCPUKernelMod::SparseMulDense(void *mat1_indices, S *mat1_values, si
|
|||
T _col = mat1_ids[i + mat1_vals_num];
|
||||
unrepeated[_row].push_back(_col);
|
||||
co_map_idx[_row][_col].push_back(mat1_values[i]);
|
||||
for (int64_t j = 0; j < mat2_col_; j++) {
|
||||
for (int64_t j = 0; j < mat2_col; j++) {
|
||||
if (idx_map_cnt[_row][j] == 0) {
|
||||
idx_map_cnt[_row][j] = this->cnt_;
|
||||
idx_map_cnt[_row][j] = SizeToUint(this->cnt_);
|
||||
this->cnt_++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<T> res;
|
||||
for (auto it = unrepeated.begin(); it != unrepeated.end(); it++) {
|
||||
for (auto it = unrepeated.begin(); it != unrepeated.end(); ++it) {
|
||||
res.push_back(it->first);
|
||||
}
|
||||
|
||||
|
@ -353,10 +363,11 @@ void SspaddmmCPUKernelMod::SparseMulDense(void *mat1_indices, S *mat1_values, si
|
|||
for (auto row_mat2 : unrepeated[row_mat1]) {
|
||||
S val = co_map_idx[row_mat1][row_mat2].back();
|
||||
co_map_idx[row_mat1][row_mat2].pop_back();
|
||||
for (int64_t j = 0; j < mat2_col_; j++) {
|
||||
for (int64_t j = 0; j < mat2_col; j++) {
|
||||
// get val
|
||||
T idx = idx_map_cnt[row_mat1][j];
|
||||
*(y_values + idx) += val * mat2_addr[row_mat2 * mat2_col_ + j];
|
||||
size_t idx = static_cast<size_t>(idx_map_cnt[row_mat1][j]);
|
||||
int64_t row_mat2_long = static_cast<int64_t>(row_mat2);
|
||||
*(y_values + idx) += val * mat2_addr[row_mat2_long * mat2_col + j];
|
||||
y_indices[idx] = static_cast<int64_t>(row_mat1);
|
||||
y_indices[idx + y_vals_num] = j;
|
||||
}
|
||||
|
|
|
@ -30,28 +30,30 @@ class SspaddmmCPUKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
SspaddmmCPUKernelMod() = default;
|
||||
~SspaddmmCPUKernelMod() override = default;
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
void CheckParam(const CNodePtr &kernel_node);
|
||||
void CheckSparseIndices(const TypeId &indices_dtype, void *indices_addr, void *shape_addr, size_t num,
|
||||
const std::string &x_name) const;
|
||||
void CheckParam(const CNodePtr &kernel_node) const;
|
||||
template <typename T, typename S>
|
||||
void CheckSparseIndicesLegal(void *indices_addr, void *shape_addr, size_t num, std::string x_name);
|
||||
void CheckSparseIndicesLegal(void *indices_addr, void *shape_addr, size_t num, const std::string &x_name) const;
|
||||
template <typename T>
|
||||
void InitShape(void *input_shape, int64_t *y_shape);
|
||||
void InitShape(void *input_shape, int64_t *y_shape) const;
|
||||
template <typename T>
|
||||
void ClearSparseValues(T *sparse_val, size_t data_num);
|
||||
template <typename T>
|
||||
T *ScalarSparseMul(T *sparse_val, void *scalar_val, size_t data_num, TypeId tid);
|
||||
T *ScalarSparseMul(const T *sparse_val, void *scalar_val, size_t data_num, const TypeId &tid);
|
||||
template <typename T, typename S>
|
||||
void SparseAddSparse(void *input_indices, S *inut_values, size_t input_num, int64_t *y_indices, S *y_values,
|
||||
void SparseAddSparse(void *input_indices, const S *inut_values, size_t input_num, int64_t *y_indices, S *y_values,
|
||||
size_t y_num);
|
||||
template <typename T, typename S>
|
||||
void SparseMulDense(void *mat1_indices, S *mat1_values, size_t mat1_vals_num, S *mat2_values_tensor,
|
||||
int64_t *y_indices, S *y_values, size_t y_vals_num, int64_t row, int64_t mat2_col);
|
||||
void SparseMulDense(void *mat1_indices, const S *mat1_values, size_t mat1_vals_num, const S *mat2_addr,
|
||||
int64_t *y_indices, S *y_values, size_t y_vals_num, int64_t mat2_col);
|
||||
|
||||
TypeId output_values_dtype_{kTypeUnknown};
|
||||
TypeId input_indices_dtype_{kTypeUnknown};
|
||||
|
|
|
@ -724,25 +724,35 @@ class EinsumHelper {
|
|||
}
|
||||
ShapeVector temp_vec(input_shapes[idx_input].begin() + idx_left,
|
||||
input_shapes[idx_input].begin() + idx_shape_right + 1);
|
||||
|
||||
if (element_shape_map_.find(ELL_VAL) != element_shape_map_.end()) {
|
||||
if (element_shape_map_[ELL_VAL] != temp_vec) {
|
||||
if (temp_vec.size() == 0 || element_shape_map_[ELL_VAL].size() == 0) {
|
||||
element_shape_map_[ELL_VAL] = temp_vec.size() == 0 ? element_shape_map_[ELL_VAL] : temp_vec;
|
||||
} else {
|
||||
MS_LOG(ERROR)
|
||||
<< "For " << node_name_
|
||||
<< ", the same ellipsis in equation can only represent the same dimension in inputs, but it does not.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
element_shape_map_[ELL_VAL] = temp_vec;
|
||||
if (!AdjustElementMapShape(temp_vec)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AdjustElementMapShape(const ShapeVector &temp_vec) {
|
||||
const size_t ellipsis_val_num = 52;
|
||||
auto iter = element_shape_map_.find(ellipsis_val_num);
|
||||
if (iter != element_shape_map_.end()) {
|
||||
ShapeVector cur_vec = iter->second;
|
||||
if (cur_vec != temp_vec) {
|
||||
if (temp_vec.empty() || cur_vec.empty()) {
|
||||
element_shape_map_[ellipsis_val_num] = temp_vec.empty() ? cur_vec : temp_vec;
|
||||
} else {
|
||||
MS_LOG(ERROR)
|
||||
<< "For " << node_name_
|
||||
<< ", the same ellipsis in equation can only represent the same dimension in inputs, but it does not.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
element_shape_map_[ellipsis_val_num] = temp_vec;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void CalAxisShape(const std::vector<size_t> &axis_val, const ShapeVector &shape_val, size_t *idx,
|
||||
ShapeVector *re_shape, std::vector<size_t> *res_trans_axis) {
|
||||
for (auto val : axis_val) {
|
||||
|
|
|
@ -90,7 +90,8 @@ std::vector<int64_t> GetOutputShape(const PrimitivePtr &primitive, const std::ve
|
|||
}
|
||||
} else {
|
||||
out_d = DoubleToLong(std::floor((in_d + pad_list[0] + pad_list[1] - kernel_d) / stride_d + 1));
|
||||
out_h = DoubleToLong(std::floor((in_h + pad_list[2] + pad_list[3] - kernel_h) / stride_h + 1));
|
||||
out_h =
|
||||
DoubleToLong(std::floor((in_h + pad_list[kInputIndex2] + pad_list[kInputIndex3] - kernel_h) / stride_h + 1));
|
||||
out_w =
|
||||
DoubleToLong(std::floor((in_w + pad_list[kInputIndex4] + pad_list[kInputIndex5] - kernel_w) / stride_w + 1));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue