forked from mindspore-Ecosystem/mindspore
!5003 GPU inset transpose pass optimize
Merge pull request !5003 from VectorSL/opt-transpose
This commit is contained in:
commit
eeb3b1a272
|
@ -33,7 +33,25 @@ std::vector<int> TransposeAxis(const std::string &src_format, const std::string
|
|||
} else if ((src_format == kOpFormat_NHWC) && (dst_format == kOpFormat_NCHW)) {
|
||||
return {0, 3, 1, 2};
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invaild format transform, from " << src_format << " to " << dst_format;
|
||||
MS_LOG(EXCEPTION) << "Invalid format transform, from " << src_format << " to " << dst_format;
|
||||
}
|
||||
}
|
||||
|
||||
// Transpose can be replaceed by nop reshape in some situations.
|
||||
// 1. out_shape [x, 1, 1, y] with transpose perm {0, 2, 3, 1}
|
||||
// 2. out_shape [x, y, 1, 1] with transpose perm {0, 3, 1, 2}
|
||||
bool IsFakeTranspose(const std::vector<size_t> &out_shape, const std::vector<int> &transpose_perm) {
|
||||
if (out_shape.size() != 4) {
|
||||
MS_LOG(EXCEPTION) << "Invalid data shape, 4-D data was needed, but get " << out_shape.size() << "-D.";
|
||||
}
|
||||
std::vector<int> perm1 = {0, 2, 3, 1};
|
||||
std::vector<int> perm2 = {0, 3, 1, 2};
|
||||
if (transpose_perm == perm1) {
|
||||
return (out_shape[1] == 1 && out_shape[2] == 1);
|
||||
} else if (transpose_perm == perm2) {
|
||||
return (out_shape[2] == 1 && out_shape[3] == 1);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -56,8 +74,16 @@ void SetTransposeOpBuildInfo(const std::string &input_format, const std::string
|
|||
CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node,
|
||||
int used_node_index, const std::vector<int> &transpose_perm) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
// 1.Create a transpose node.
|
||||
auto transpose_prim = std::make_shared<Primitive>(prim::kPrimTranspose->name());
|
||||
// 0.Judge whether it is a fake transpose
|
||||
auto transed_shape = AnfAlgo::GetInputDeviceShape(used_node, used_node_index);
|
||||
bool is_fake = IsFakeTranspose(transed_shape, transpose_perm);
|
||||
// 1.Create a transpose node or a fake transpose node:reshape.
|
||||
mindspore::PrimitivePtr transpose_prim;
|
||||
if (is_fake) {
|
||||
transpose_prim = std::make_shared<Primitive>(prim::kPrimReshape->name());
|
||||
} else {
|
||||
transpose_prim = std::make_shared<Primitive>(prim::kPrimTranspose->name());
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(transpose_prim);
|
||||
// 2.Set the input of transpose.
|
||||
std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node};
|
||||
|
@ -66,7 +92,9 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co
|
|||
auto transpose_type = {AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)};
|
||||
auto transpose_shape = {AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index)};
|
||||
AnfAlgo::SetOutputInferTypeAndShape(transpose_type, transpose_shape, transpose_op.get());
|
||||
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
|
||||
if (!is_fake) {
|
||||
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
|
||||
}
|
||||
// 4.Set the input of used_node.
|
||||
MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
|
||||
<< ", index: " << used_node_index;
|
||||
|
|
|
@ -57,7 +57,7 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
|
|||
if (item_idx == 0) {
|
||||
auto cast = GetRealNodeUsedList(graph, outlist->at(i).first);
|
||||
if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") {
|
||||
return nullptr;
|
||||
continue;
|
||||
}
|
||||
manager->Replace(utils::cast<CNodePtr>(cast->at(0).first), utils::cast<CNodePtr>(outlist->at(i).first));
|
||||
outputs_type.push_back(kNumberTypeFloat16);
|
||||
|
|
Loading…
Reference in New Issue