fix some bugs of ynamicBroadcastGradientArgsCpuKernelMod code
This commit is contained in:
parent
b1a32c6c89
commit
8672fe4afb
|
@ -163,7 +163,7 @@ bool DynamicBroadcastGradientArgsCpuKernelMod::Init(const BaseOperatorPtr &base_
|
||||||
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
|
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
is_need_retrieve_output_shape_ = true;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -177,7 +177,6 @@ int DynamicBroadcastGradientArgsCpuKernelMod::Resize(const BaseOperatorPtr &base
|
||||||
}
|
}
|
||||||
// get input_shape
|
// get input_shape
|
||||||
outputs_ = outputs;
|
outputs_ = outputs;
|
||||||
is_need_retrieve_output_shape_ = true;
|
|
||||||
|
|
||||||
return static_cast<int>(KRET_OK);
|
return static_cast<int>(KRET_OK);
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ namespace kernel {
|
||||||
class DynamicBroadcastGradientArgsCpuKernelMod : public NativeCpuKernelMod,
|
class DynamicBroadcastGradientArgsCpuKernelMod : public NativeCpuKernelMod,
|
||||||
public MatchKernelHelper<DynamicBroadcastGradientArgsCpuKernelMod> {
|
public MatchKernelHelper<DynamicBroadcastGradientArgsCpuKernelMod> {
|
||||||
public:
|
public:
|
||||||
DynamicBroadcastGradientArgsCpuKernelMod() : r0_size_(0), r1_size_(0) { ResetResource(); }
|
DynamicBroadcastGradientArgsCpuKernelMod() : r0_size_(0), r1_size_(0) {}
|
||||||
~DynamicBroadcastGradientArgsCpuKernelMod() override = default;
|
~DynamicBroadcastGradientArgsCpuKernelMod() override = default;
|
||||||
|
|
||||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
@ -51,12 +51,6 @@ class DynamicBroadcastGradientArgsCpuKernelMod : public NativeCpuKernelMod,
|
||||||
|
|
||||||
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
|
||||||
|
|
||||||
void ResetResource() noexcept {
|
|
||||||
input_size_list_.clear();
|
|
||||||
output_size_list_.clear();
|
|
||||||
workspace_size_list_.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<KernelTensorPtr> GetOutputs() override {
|
std::vector<KernelTensorPtr> GetOutputs() override {
|
||||||
ShapeVector r0_shape{SizeToLong(r0_size_)};
|
ShapeVector r0_shape{SizeToLong(r0_size_)};
|
||||||
ShapeVector r1_shape{SizeToLong(r1_size_)};
|
ShapeVector r1_shape{SizeToLong(r1_size_)};
|
||||||
|
|
Loading…
Reference in New Issue