!40700 fixed median dynamic rank for gpu

Merge pull request !40700 from liuchao/median
This commit is contained in:
i-robot 2022-08-23 03:18:02 +00:00 committed by Gitee
commit e9c85db0ed
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 14 additions and 20 deletions

View File

@ -62,7 +62,19 @@ class MedianGpuKernelMod : public NativeGpuKernelMod {
}
global_median_ = kernel_ptr->get_global_median();
keep_dims_ = kernel_ptr->get_keep_dims();
axis_ = kernel_ptr->get_axis();
attr_axis_ = kernel_ptr->get_axis();
return true;
}
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override {
int ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != 0) {
return ret;
}
axis_ = attr_axis_;
input_shape_ = inputs[0]->GetShapeVector();
if (global_median_) {
int input_size = 1;
@ -96,25 +108,6 @@ class MedianGpuKernelMod : public NativeGpuKernelMod {
axis_ += dims;
}
}
return true;
}
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override {
int ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != 0) {
return ret;
}
input_shape_ = inputs[0]->GetShapeVector();
if (global_median_) {
int input_size = 1;
for (size_t i = 0; i < input_shape_.size(); i++) {
input_size *= input_shape_[i];
}
input_shape_.clear();
input_shape_.push_back(input_size);
}
return KRET_OK;
}
@ -131,6 +124,7 @@ class MedianGpuKernelMod : public NativeGpuKernelMod {
private:
bool global_median_;
bool keep_dims_;
int64_t attr_axis_;
int64_t axis_;
std::vector<int64_t> input_shape_;
};