forked from mindspore-Ecosystem/mindspore
!19244 fix lstm build error
Merge pull request !19244 from kisnwang/fix-lstm-build-error
This commit is contained in:
commit
39659af1e8
|
@ -362,13 +362,13 @@ void SetWeightFormat(const AnfNodePtr &real_input_node, std::vector<string> outp
|
|||
}
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool need_convert = context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK);
|
||||
if (need_convert) {
|
||||
need_convert =
|
||||
trans::kTransFormatMapOfHostToDevice.find(output_format[0]) != trans::kTransFormatMapOfHostToDevice.end();
|
||||
bool disable_convert = real_input_node->isa<Parameter>() || real_input_node->isa<ValueNode>();
|
||||
if (disable_convert && context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK)) {
|
||||
disable_convert =
|
||||
trans::kTransFormatMapOfHostToDevice.find(output_format[0]) == trans::kTransFormatMapOfHostToDevice.end();
|
||||
}
|
||||
// if not find in host convert format map means the host has not registered the convert function of this format
|
||||
if (real_input_node->isa<Parameter>() && output_format[0] != kOpFormat_DEFAULT && !need_convert) {
|
||||
if (output_format[0] != kOpFormat_DEFAULT && disable_convert) {
|
||||
output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
|
||||
}
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
|
|
Loading…
Reference in New Issue