forked from mindspore-Ecosystem/mindspore
!11997 update io_format to format
From: @liubuyu Reviewed-by: @kisnwang,@zhoufeng54 Signed-off-by: @zhoufeng54
This commit is contained in:
commit
f8fc047f2b
|
@ -65,8 +65,8 @@ void SetTransNodeAttr(const CNodePtr &trans_node) {
|
|||
|
||||
std::string InitDefaultFormat(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>() && AnfAlgo::HasNodeAttr("io_format", node->cast<CNodePtr>())) {
|
||||
auto attr = AnfAlgo::GetNodeAttr<std::string>(node, "io_format");
|
||||
if (node->isa<CNode>() && AnfAlgo::HasNodeAttr(kAttrFormat, node->cast<CNodePtr>())) {
|
||||
auto attr = AnfAlgo::GetNodeAttr<std::string>(node, kAttrFormat);
|
||||
if (attr == kOpFormat_NCDHW) {
|
||||
return kOpFormat_NCDHW;
|
||||
}
|
||||
|
@ -127,11 +127,11 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
|
|||
std::string output_format = AnfAlgo::GetOutputFormat(node, 0);
|
||||
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, 0);
|
||||
if (output_format == kOpFormat_NC1KHKWHWC0) {
|
||||
MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node "
|
||||
MS_LOG(EXCEPTION) << "Got the hw format " << output_format << "when insert the transdata node "
|
||||
<< node->DebugString() << " trace: " << trace::DumpSourceLines(node);
|
||||
}
|
||||
if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
|
||||
MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
|
||||
MS_LOG(DEBUG) << "Inserted transdata " << output_format << " to default , index :0";
|
||||
return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
|
||||
}
|
||||
return node;
|
||||
|
@ -364,7 +364,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
|
|||
const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index);
|
||||
const std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(prev_node.first, prev_node.second);
|
||||
// In graph kernel, we check parameter,
|
||||
// the eliminate pass will not eliminate this case, so we just do not insert the noused cast.
|
||||
// the eliminate pass will not eliminate this case, so we just do not insert the no used cast.
|
||||
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) {
|
||||
new_inputs.push_back(cur_input);
|
||||
} else if (TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); origin_type != device_type) {
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <set>
|
||||
#include "base/core_ops.h"
|
||||
#include "ir/param_info.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
#include "backend/kernel_compiler/kernel_build_info.h"
|
||||
|
@ -400,8 +401,8 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
|
|||
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
|
||||
}
|
||||
SetKernelInfoForNode(cnode);
|
||||
if (AnfAlgo::HasNodeAttr("io_format", cnode)) {
|
||||
auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, "io_format");
|
||||
if (AnfAlgo::HasNodeAttr(kAttrFormat, cnode)) {
|
||||
auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFormat);
|
||||
if (attr == kOpFormat_NCDHW) {
|
||||
ResetInFormat(cnode, kOpFormat_NCDHW);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue