fix transdata's dst format && src format is unmatched with build info when transdata has been spilted

This commit is contained in:
LianLiguang 2021-03-31 15:54:30 +08:00
parent 36dbb2690e
commit 9c8d016d66
3 changed files with 41 additions and 31 deletions

View File

@ -99,7 +99,7 @@
#include "backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h"
#include "backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h"
#include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h"
#include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h"
#include "backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.h"
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h"
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h"
#include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h"
@ -254,7 +254,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>());
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>());
mixed_precision_pm->AddPass(std::make_shared<DealRefAndSpiltUnSupportedTransdata>());
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h"
#include "backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.h"
#include <utility>
#include <vector>
#include <memory>
@ -26,7 +26,7 @@
namespace mindspore {
namespace opt {
session::KernelWithIndex DealRefTransAndCast::FindRefOriginNode(const AnfNodePtr &node) const {
session::KernelWithIndex DealRefAndSpiltUnSupportedTransdata::FindRefOriginNode(const AnfNodePtr &node) const {
session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0);
AnfNodePtr cur_node = kernel_with_index.first;
size_t cur_out_index = kernel_with_index.second;
@ -61,8 +61,9 @@ session::KernelWithIndex DealRefTransAndCast::FindRefOriginNode(const AnfNodePtr
return kernel_with_index;
}
void DealRefTransAndCast::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const size_t output_index, const size_t input_index) const {
void DealRefAndSpiltUnSupportedTransdata::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph,
const CNodePtr &cnode, const size_t output_index,
const size_t input_index) const {
// record the ref_pair
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
@ -71,10 +72,10 @@ void DealRefTransAndCast::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_g
kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index);
}
void DealRefTransAndCast::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const AnfNodePtr &get_item, const AnfNodePtr &final_node,
size_t final_index,
const session::KernelWithIndex &origin_pair) const {
void DealRefAndSpiltUnSupportedTransdata::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const AnfNodePtr &get_item,
const AnfNodePtr &final_node, size_t final_index,
const session::KernelWithIndex &origin_pair) const {
// record the ref_pair
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
@ -95,9 +96,10 @@ void DealRefTransAndCast::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph
// if get_item is nullptr, the additional node will link to the cnode
// else the additional node will link to the get_item node (the get_item node link to cnode)
CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
size_t output_index, size_t input_index,
const CNodePtr &get_item) const {
CNodePtr DealRefAndSpiltUnSupportedTransdata::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph,
const CNodePtr &cnode, size_t output_index,
size_t input_index,
const CNodePtr &get_item) const {
CNodePtr final_node = (get_item == nullptr ? cnode : get_item);
bool need_refresh_ref_addr = false;
size_t final_index = output_index;
@ -149,8 +151,9 @@ CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_
return final_node;
}
CNodePtr DealRefTransAndCast::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node,
const CNodePtr &cnode, const FuncGraphPtr &func_graph) const {
CNodePtr DealRefAndSpiltUnSupportedTransdata::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node,
const CNodePtr &cnode,
const FuncGraphPtr &func_graph) const {
std::vector<AnfNodePtr> depend_nodes;
if (get_item != nullptr) {
depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), get_item, final_node};
@ -159,8 +162,8 @@ CNodePtr DealRefTransAndCast::MakeDependency(const CNodePtr &get_item, const CNo
}
return func_graph->NewCNode(depend_nodes);
}
CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::shared_ptr<kernel::OpInfo> &op_info) const {
CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput(
const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const {
MS_EXCEPTION_IF_NULL(op_info);
auto ref_infos = op_info->ref_infos();
std::vector<AnfNodePtr> make_tuple_inputs;
@ -185,8 +188,8 @@ CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_
return make_tuple;
}
CNodePtr DealRefTransAndCast::DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::shared_ptr<kernel::OpInfo> &op_info) const {
CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::shared_ptr<kernel::OpInfo> &op_info) const {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(op_info);
auto ref_infos = op_info->ref_infos();
@ -200,13 +203,14 @@ CNodePtr DealRefTransAndCast::DealRefSigleOutput(const FuncGraphPtr &func_graph,
return nullptr;
}
const BaseRef DealRefTransAndCast::DefinePattern() const {
const BaseRef DealRefAndSpiltUnSupportedTransdata::DefinePattern() const {
VarPtr V = std::make_shared<CondVar>(UnVisited);
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({V, Xs});
}
void DealRefTransAndCast::DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const {
void DealRefAndSpiltUnSupportedTransdata::DealBroadCastAsRef(const FuncGraphPtr &func_graph,
const CNodePtr &cnode) const {
if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) {
auto input_size = AnfAlgo::GetInputTensorNum(cnode);
for (size_t i = 0; i < input_size; ++i) {
@ -219,8 +223,8 @@ void DealRefTransAndCast::DealBroadCastAsRef(const FuncGraphPtr &func_graph, con
}
}
const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
const AnfNodePtr DealRefAndSpiltUnSupportedTransdata::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>()) {
return nullptr;
}
@ -250,11 +254,12 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A
return nullptr;
}
CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph,
const CNodePtr &cnode) const {
CNodePtr DealRefAndSpiltUnSupportedTransdata::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph,
const CNodePtr &cnode) const {
MS_EXCEPTION_IF_NULL(cnode);
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(cnode);
MS_EXCEPTION_IF_NULL(kernel_info);
// When the input and output format is only one special format just need to be splited into transpose and transdata
if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() ||
kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) {
if (IsFormatInvaild(cnode)) {
@ -262,6 +267,8 @@ CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &f
}
return cnode;
}
// When input and output format are all special format
// the node should be splited to two transdata connected by default format
auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT});
@ -273,6 +280,8 @@ CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &f
next_trans_node->set_abstract(cnode->abstract());
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), cnode.get());
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get());
RefreshKernelBuildInfo(AnfAlgo::GetInputFormat(cnode, 0), kOpFormat_DEFAULT, cnode);
RefreshKernelBuildInfo(kOpFormat_DEFAULT, AnfAlgo::GetOutputFormat(next_trans_node, 0), next_trans_node);
if (IsFormatInvaild(cnode)) {
auto after_split_node = DoSplit(func_graph, cnode);
AnfAlgo::SetNodeInput(next_trans_node, after_split_node, 0);

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_AND_SPLIT_UNSUPPORTED_TRANSADATA_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_AND_SPLIT_UNSUPPORTED_TRANSADATA_H_
#include <memory>
#include "ir/anf.h"
#include "backend/optimizer/common/optimizer.h"
@ -25,10 +25,11 @@
namespace mindspore {
namespace opt {
class DealRefTransAndCast : public TransDataSplit {
class DealRefAndSpiltUnSupportedTransdata : public TransDataSplit {
public:
explicit DealRefTransAndCast(bool multigraph = true) : TransDataSplit(multigraph, "deal_ref_trans_and_cast") {}
~DealRefTransAndCast() override = default;
explicit DealRefAndSpiltUnSupportedTransdata(bool multigraph = true)
: TransDataSplit(multigraph, "deal_ref_and_transdata_spilt") {}
~DealRefAndSpiltUnSupportedTransdata() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
@ -52,4 +53,4 @@ class DealRefTransAndCast : public TransDataSplit {
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_AND_SPLIT_UNSUPPORTED_TRANSADATA_H_