forked from mindspore-Ecosystem/mindspore
Fix cyber models precision
This commit is contained in:
parent
3bdd744d47
commit
665ce6b7cd
|
@ -64,6 +64,14 @@ STATUS ReduceFusionMapper::Mapper(const CNodePtr &cnode) {
|
|||
dst_prim = reduce_all.GetPrim();
|
||||
} else if (mode == static_cast<int64_t>(ReduceMode::Reduce_L2)) {
|
||||
ops::LpNorm lp_norm_op;
|
||||
auto axes_ptr = src_prim->GetAttr(ops::kAxes);
|
||||
if (axes_ptr != nullptr) {
|
||||
auto axes = GetValue<std::vector<int32_t>>(axes_ptr);
|
||||
std::vector<int64_t> axes_vec;
|
||||
std::transform(axes.begin(), axes.end(), std::back_inserter(axes),
|
||||
[](int32_t x) { return static_cast<int64_t>(x); });
|
||||
lp_norm_op.set_axis(axes_vec);
|
||||
}
|
||||
dst_prim = lp_norm_op.GetPrim();
|
||||
} else if (mode == static_cast<int64_t>(ReduceMode::Reduce_Prod)) {
|
||||
dst_prim = std::make_shared<acl::DynamicReduceProd>();
|
||||
|
|
Loading…
Reference in New Issue