!19244 fix lstm build error

Merge pull request !19244 from kisnwang/fix-lstm-build-error
This commit is contained in:
i-robot 2021-07-02 06:29:33 +00:00 committed by Gitee
commit 39659af1e8
1 changed files with 5 additions and 5 deletions

View File

@ -362,13 +362,13 @@ void SetWeightFormat(const AnfNodePtr &real_input_node, std::vector<string> outp
}
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool need_convert = context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK);
if (need_convert) {
need_convert =
trans::kTransFormatMapOfHostToDevice.find(output_format[0]) != trans::kTransFormatMapOfHostToDevice.end();
bool disable_convert = real_input_node->isa<Parameter>() || real_input_node->isa<ValueNode>();
if (disable_convert && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK)) {
disable_convert =
trans::kTransFormatMapOfHostToDevice.find(output_format[0]) == trans::kTransFormatMapOfHostToDevice.end();
}
// if not find in host convert format map means the host has not registered the convert function of this format
if (real_input_node->isa<Parameter>() && output_format[0] != kOpFormat_DEFAULT && !need_convert) {
if (output_format[0] != kOpFormat_DEFAULT && disable_convert) {
output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
}
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();