!45683 [MS][BUG]fix bug of clip_by_norm
Merge pull request !45683 from mengyuanli/fix_bug_clip_by_norm
This commit is contained in:
commit
f5adc80b63
|
@ -48,9 +48,7 @@ bool ClipByNormCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const st
|
|||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
// Get C++ primitive and kernel_name
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
auto prim = std::dynamic_pointer_cast<ops::ClipByNorm>(base_operator);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
kernel_name_ = prim->name();
|
||||
kernel_name_ = base_operator->name();
|
||||
// Check whether current input and output data types are valid.
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
if (!MatchKernelAttr(kernel_attr, GetOpSupport()).first) {
|
||||
|
@ -67,15 +65,15 @@ int ClipByNormCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const s
|
|||
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
auto prim = std::dynamic_pointer_cast<ops::ClipByNorm>(base_operator);
|
||||
ResetResource();
|
||||
// Init basic variables
|
||||
InitIOShape(inputs, outputs);
|
||||
InitAxisAndEpsilon(prim);
|
||||
// Init the `l2_norm` reduce shape according to `axis`
|
||||
l2_norm_output_shape_ = x_shape_;
|
||||
auto prim = std::dynamic_pointer_cast<ops::ClipByNorm>(base_operator);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
InitAxisAndEpsilon(prim);
|
||||
(void)std::for_each(axis_.begin(), axis_.end(), [this](const size_t &idx) { l2_norm_output_shape_[idx] = 1; });
|
||||
|
||||
InitSizeLists();
|
||||
return KRET_OK;
|
||||
}
|
||||
|
@ -109,14 +107,13 @@ std::vector<KernelAttr> ClipByNormCpuKernelMod::GetOpSupport() { return clip_by_
|
|||
void ClipByNormCpuKernelMod::ResetResource() {
|
||||
epsilon_ = 0.000001f;
|
||||
x_dim_ = 0;
|
||||
data_type_ = std::make_pair(kNumberTypeFloat32, kNumberTypeFloat32);
|
||||
axis_.clear();
|
||||
x_shape_.clear();
|
||||
clip_norm_shape_.clear();
|
||||
l2_norm_output_shape_.clear();
|
||||
output_shape_.clear();
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
void ClipByNormCpuKernelMod::InitIOShape(const std::vector<KernelTensorPtr> &inputs,
|
||||
|
@ -161,6 +158,7 @@ void ClipByNormCpuKernelMod::InitAxisAndEpsilon(const ops::ClipByNormPtr &prim)
|
|||
MS_EXCEPTION(TypeError) << "For `" << kernel_name_ << "`, the type of attribute `axis` is invalid.";
|
||||
}
|
||||
// Init `axis_`
|
||||
axis_.clear();
|
||||
if (temp_axis.empty()) {
|
||||
for (size_t i = 0; i < x_dim_; ++i) {
|
||||
(void)axis_.emplace_back(i); // Reduce for all dimensions.
|
||||
|
|
Loading…
Reference in New Issue