forked from mindspore-Ecosystem/mindspore
!18542 weight format fresh bug fix
Merge pull request !18542 from liubuyu/bug_fix
This commit is contained in:
commit
55e646ff4d
|
@ -355,11 +355,16 @@ void SetCastAndWeightFormat(const CNodePtr &kernel_node) {
|
|||
AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), kernel_node.get());
|
||||
}
|
||||
|
||||
void SetWeightFormat(const AnfNodePtr &real_input_node, const std::vector<string> &output_format,
|
||||
const CNodePtr &kernel_node, size_t input_index, bool force_fresh = false) {
|
||||
void SetWeightFormat(const AnfNodePtr &real_input_node, std::vector<string> output_format, const CNodePtr &kernel_node,
|
||||
size_t input_index, bool force_fresh = false) {
|
||||
if (real_input_node->isa<CNode>() || AnfAlgo::OutputAddrExist(real_input_node, 0)) {
|
||||
return;
|
||||
}
|
||||
// if not find in host convert format map means the host has not registered the convert function of this format
|
||||
if (real_input_node->isa<Parameter>() && output_format[0] != kOpFormat_DEFAULT &&
|
||||
trans::kTransFormatMapOfHostToDevice.find(output_format[0]) == trans::kTransFormatMapOfHostToDevice.end()) {
|
||||
output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
|
||||
}
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
// we set special device info of a input tensor.
|
||||
auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel_node);
|
||||
|
|
Loading…
Reference in New Issue