forked from OSSInnovation/mindspore
!3196 Fix precision error with fp16 input in PyNative mode
Merge pull request !3196 from chujinjin/fix_precision_error_with_fp16_input
This commit is contained in:
commit
ab53809f2c
|
@ -167,7 +167,8 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
|
|||
}
|
||||
} // namespace
|
||||
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
|
||||
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) {
|
||||
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type,
|
||||
const TypeId &type_id) {
|
||||
MS_EXCEPTION_IF_NULL(trans_data);
|
||||
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
|
||||
MS_EXCEPTION_IF_NULL(ori_build_info);
|
||||
|
@ -176,6 +177,10 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
|
|||
builder->SetInputReshapeType({reshape_type});
|
||||
builder->SetOutputReshapeType({reshape_type});
|
||||
builder->SetOutputsFormat({output_format});
|
||||
if (type_id != kTypeUnknown) {
|
||||
builder->SetOutputsDeviceType({type_id});
|
||||
builder->SetInputsDeviceType({type_id});
|
||||
}
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
|
||||
}
|
||||
|
||||
|
|
|
@ -86,7 +86,8 @@ class OpFinder {
|
|||
using OpFinderPtr = std::shared_ptr<OpFinder>;
|
||||
|
||||
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
|
||||
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {});
|
||||
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {},
|
||||
const TypeId &type_id = kTypeUnknown);
|
||||
|
||||
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
|
||||
const bool need_padding, const std::string &op_name);
|
||||
|
|
|
@ -107,7 +107,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
|
|||
if (origin_format != cur_format && cur_shape.size() > 1) {
|
||||
auto kernel_select = std::make_shared<KernelSelect>();
|
||||
final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name());
|
||||
RefreshKernelBuildInfo(cur_format, origin_format, final_node);
|
||||
RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type);
|
||||
final_index = 0;
|
||||
MS_EXCEPTION_IF_NULL(final_node);
|
||||
MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
|
||||
|
|
Loading…
Reference in New Issue