fix some bugs of ynamicBroadcastGradientArgsCpuKernelMod code

This commit is contained in:
hanhuifeng2020 2022-09-15 16:14:03 +08:00
parent b1a32c6c89
commit 8672fe4afb
2 changed files with 2 additions and 9 deletions

View File

@ -163,7 +163,7 @@ bool DynamicBroadcastGradientArgsCpuKernelMod::Init(const BaseOperatorPtr &base_
if (!MatchKernelFunc(base_operator, inputs, outputs)) { if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false; return false;
} }
is_need_retrieve_output_shape_ = true;
return true; return true;
} }
@ -177,7 +177,6 @@ int DynamicBroadcastGradientArgsCpuKernelMod::Resize(const BaseOperatorPtr &base
} }
// get input_shape // get input_shape
outputs_ = outputs; outputs_ = outputs;
is_need_retrieve_output_shape_ = true;
return static_cast<int>(KRET_OK); return static_cast<int>(KRET_OK);
} }

View File

@ -31,7 +31,7 @@ namespace kernel {
class DynamicBroadcastGradientArgsCpuKernelMod : public NativeCpuKernelMod, class DynamicBroadcastGradientArgsCpuKernelMod : public NativeCpuKernelMod,
public MatchKernelHelper<DynamicBroadcastGradientArgsCpuKernelMod> { public MatchKernelHelper<DynamicBroadcastGradientArgsCpuKernelMod> {
public: public:
DynamicBroadcastGradientArgsCpuKernelMod() : r0_size_(0), r1_size_(0) { ResetResource(); } DynamicBroadcastGradientArgsCpuKernelMod() : r0_size_(0), r1_size_(0) {}
~DynamicBroadcastGradientArgsCpuKernelMod() override = default; ~DynamicBroadcastGradientArgsCpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs, bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
@ -51,12 +51,6 @@ class DynamicBroadcastGradientArgsCpuKernelMod : public NativeCpuKernelMod,
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); } std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
void ResetResource() noexcept {
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
std::vector<KernelTensorPtr> GetOutputs() override { std::vector<KernelTensorPtr> GetOutputs() override {
ShapeVector r0_shape{SizeToLong(r0_size_)}; ShapeVector r0_shape{SizeToLong(r0_size_)};
ShapeVector r1_shape{SizeToLong(r1_size_)}; ShapeVector r1_shape{SizeToLong(r1_size_)};