forked from OSSInnovation/mindspore
!3196 fix precision error with fp16 input on pynative mode
Merge pull request !3196 from chujinjin/fix_precision_error_with_fp16_input
This commit is contained in:
commit
ab53809f2c
|
@ -167,7 +167,8 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
|
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
|
||||||
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) {
|
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type,
|
||||||
|
const TypeId &type_id) {
|
||||||
MS_EXCEPTION_IF_NULL(trans_data);
|
MS_EXCEPTION_IF_NULL(trans_data);
|
||||||
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
|
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
|
||||||
MS_EXCEPTION_IF_NULL(ori_build_info);
|
MS_EXCEPTION_IF_NULL(ori_build_info);
|
||||||
|
@ -176,6 +177,10 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
|
||||||
builder->SetInputReshapeType({reshape_type});
|
builder->SetInputReshapeType({reshape_type});
|
||||||
builder->SetOutputReshapeType({reshape_type});
|
builder->SetOutputReshapeType({reshape_type});
|
||||||
builder->SetOutputsFormat({output_format});
|
builder->SetOutputsFormat({output_format});
|
||||||
|
if (type_id != kTypeUnknown) {
|
||||||
|
builder->SetOutputsDeviceType({type_id});
|
||||||
|
builder->SetInputsDeviceType({type_id});
|
||||||
|
}
|
||||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -86,7 +86,8 @@ class OpFinder {
|
||||||
using OpFinderPtr = std::shared_ptr<OpFinder>;
|
using OpFinderPtr = std::shared_ptr<OpFinder>;
|
||||||
|
|
||||||
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
|
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
|
||||||
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {});
|
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {},
|
||||||
|
const TypeId &type_id = kTypeUnknown);
|
||||||
|
|
||||||
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
|
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
|
||||||
const bool need_padding, const std::string &op_name);
|
const bool need_padding, const std::string &op_name);
|
||||||
|
|
|
@ -107,7 +107,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
|
||||||
if (origin_format != cur_format && cur_shape.size() > 1) {
|
if (origin_format != cur_format && cur_shape.size() > 1) {
|
||||||
auto kernel_select = std::make_shared<KernelSelect>();
|
auto kernel_select = std::make_shared<KernelSelect>();
|
||||||
final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name());
|
final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name());
|
||||||
RefreshKernelBuildInfo(cur_format, origin_format, final_node);
|
RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type);
|
||||||
final_index = 0;
|
final_index = 0;
|
||||||
MS_EXCEPTION_IF_NULL(final_node);
|
MS_EXCEPTION_IF_NULL(final_node);
|
||||||
MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
|
MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
|
||||||
|
|
Loading…
Reference in New Issue