!18542 weight format fresh bug fix

Merge pull request !18542 from liubuyu/bug_fix
This commit is contained in:
i-robot 2021-06-21 15:09:24 +08:00 committed by Gitee
commit 55e646ff4d
1 changed files with 7 additions and 2 deletions

View File

@ -355,11 +355,16 @@ void SetCastAndWeightFormat(const CNodePtr &kernel_node) {
AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), kernel_node.get());
}
void SetWeightFormat(const AnfNodePtr &real_input_node, const std::vector<string> &output_format,
const CNodePtr &kernel_node, size_t input_index, bool force_fresh = false) {
void SetWeightFormat(const AnfNodePtr &real_input_node, std::vector<string> output_format, const CNodePtr &kernel_node,
size_t input_index, bool force_fresh = false) {
if (real_input_node->isa<CNode>() || AnfAlgo::OutputAddrExist(real_input_node, 0)) {
return;
}
// if not find in host convert format map means the host has not registered the convert function of this format
if (real_input_node->isa<Parameter>() && output_format[0] != kOpFormat_DEFAULT &&
trans::kTransFormatMapOfHostToDevice.find(output_format[0]) == trans::kTransFormatMapOfHostToDevice.end()) {
output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
}
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// we set special device info of a input tensor.
auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel_node);