!45683 [MS][BUG]fix bug of clip_by_norm

Merge pull request !45683 from mengyuanli/fix_bug_clip_by_norm
This commit is contained in:
i-robot 2022-11-21 03:20:10 +00:00 committed by Gitee
commit f5adc80b63
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 6 additions and 8 deletions

View File

@ -48,9 +48,7 @@ bool ClipByNormCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const st
const std::vector<KernelTensorPtr> &outputs) {
// Get C++ primitive and kernel_name
MS_EXCEPTION_IF_NULL(base_operator);
auto prim = std::dynamic_pointer_cast<ops::ClipByNorm>(base_operator);
MS_EXCEPTION_IF_NULL(prim);
kernel_name_ = prim->name();
kernel_name_ = base_operator->name();
// Check whether current input and output data types are valid.
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
if (!MatchKernelAttr(kernel_attr, GetOpSupport()).first) {
@ -67,15 +65,15 @@ int ClipByNormCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const s
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
auto prim = std::dynamic_pointer_cast<ops::ClipByNorm>(base_operator);
ResetResource();
// Init basic variables
InitIOShape(inputs, outputs);
InitAxisAndEpsilon(prim);
// Init the `l2_norm` reduce shape according to `axis`
l2_norm_output_shape_ = x_shape_;
auto prim = std::dynamic_pointer_cast<ops::ClipByNorm>(base_operator);
MS_EXCEPTION_IF_NULL(prim);
InitAxisAndEpsilon(prim);
(void)std::for_each(axis_.begin(), axis_.end(), [this](const size_t &idx) { l2_norm_output_shape_[idx] = 1; });
InitSizeLists();
return KRET_OK;
}
@ -109,14 +107,13 @@ std::vector<KernelAttr> ClipByNormCpuKernelMod::GetOpSupport() { return clip_by_
void ClipByNormCpuKernelMod::ResetResource() {
epsilon_ = 0.000001f;
x_dim_ = 0;
data_type_ = std::make_pair(kNumberTypeFloat32, kNumberTypeFloat32);
axis_.clear();
x_shape_.clear();
clip_norm_shape_.clear();
l2_norm_output_shape_.clear();
output_shape_.clear();
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
void ClipByNormCpuKernelMod::InitIOShape(const std::vector<KernelTensorPtr> &inputs,
@ -161,6 +158,7 @@ void ClipByNormCpuKernelMod::InitAxisAndEpsilon(const ops::ClipByNormPtr &prim)
MS_EXCEPTION(TypeError) << "For `" << kernel_name_ << "`, the type of attribute `axis` is invalid.";
}
// Init `axis_`
axis_.clear();
if (temp_axis.empty()) {
for (size_t i = 0; i < x_dim_; ++i) {
(void)axis_.emplace_back(i); // Reduce for all dimensions.