!40700 fixed median dynamic rank for gpu
Merge pull request !40700 from liuchao/median
This commit is contained in:
commit
e9c85db0ed
|
@ -62,7 +62,19 @@ class MedianGpuKernelMod : public NativeGpuKernelMod {
|
|||
}
|
||||
global_median_ = kernel_ptr->get_global_median();
|
||||
keep_dims_ = kernel_ptr->get_keep_dims();
|
||||
axis_ = kernel_ptr->get_axis();
|
||||
attr_axis_ = kernel_ptr->get_axis();
|
||||
return true;
|
||||
}
|
||||
|
||||
int Resize(
|
||||
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override {
|
||||
int ret = KernelMod::Resize(base_operator, inputs, outputs);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
axis_ = attr_axis_;
|
||||
input_shape_ = inputs[0]->GetShapeVector();
|
||||
if (global_median_) {
|
||||
int input_size = 1;
|
||||
|
@ -96,25 +108,6 @@ class MedianGpuKernelMod : public NativeGpuKernelMod {
|
|||
axis_ += dims;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
int Resize(
|
||||
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override {
|
||||
int ret = KernelMod::Resize(base_operator, inputs, outputs);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
input_shape_ = inputs[0]->GetShapeVector();
|
||||
if (global_median_) {
|
||||
int input_size = 1;
|
||||
for (size_t i = 0; i < input_shape_.size(); i++) {
|
||||
input_size *= input_shape_[i];
|
||||
}
|
||||
input_shape_.clear();
|
||||
input_shape_.push_back(input_size);
|
||||
}
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
|
@ -131,6 +124,7 @@ class MedianGpuKernelMod : public NativeGpuKernelMod {
|
|||
private:
|
||||
bool global_median_;
|
||||
bool keep_dims_;
|
||||
int64_t attr_axis_;
|
||||
int64_t axis_;
|
||||
std::vector<int64_t> input_shape_;
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue