forked from OSSInnovation/mindspore
fix transdata's dst format && src format is unmatched with build info when transdata has been spilted
This commit is contained in:
parent
36dbb2690e
commit
9c8d016d66
|
@ -99,7 +99,7 @@
|
|||
#include "backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h"
|
||||
#include "backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h"
|
||||
#include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h"
|
||||
#include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h"
|
||||
#include "backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.h"
|
||||
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h"
|
||||
#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h"
|
||||
#include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h"
|
||||
|
@ -254,7 +254,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
|
|||
mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<DealRefAndSpiltUnSupportedTransdata>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
|
||||
mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h"
|
||||
#include "backend/optimizer/ascend/format_type/deal_ref_and_split_unsupported_transdata.h"
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
@ -26,7 +26,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
session::KernelWithIndex DealRefTransAndCast::FindRefOriginNode(const AnfNodePtr &node) const {
|
||||
session::KernelWithIndex DealRefAndSpiltUnSupportedTransdata::FindRefOriginNode(const AnfNodePtr &node) const {
|
||||
session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0);
|
||||
AnfNodePtr cur_node = kernel_with_index.first;
|
||||
size_t cur_out_index = kernel_with_index.second;
|
||||
|
@ -61,8 +61,9 @@ session::KernelWithIndex DealRefTransAndCast::FindRefOriginNode(const AnfNodePtr
|
|||
return kernel_with_index;
|
||||
}
|
||||
|
||||
void DealRefTransAndCast::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const size_t output_index, const size_t input_index) const {
|
||||
void DealRefAndSpiltUnSupportedTransdata::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph,
|
||||
const CNodePtr &cnode, const size_t output_index,
|
||||
const size_t input_index) const {
|
||||
// record the ref_pair
|
||||
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
|
@ -71,10 +72,10 @@ void DealRefTransAndCast::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_g
|
|||
kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index);
|
||||
}
|
||||
|
||||
void DealRefTransAndCast::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const AnfNodePtr &get_item, const AnfNodePtr &final_node,
|
||||
size_t final_index,
|
||||
const session::KernelWithIndex &origin_pair) const {
|
||||
void DealRefAndSpiltUnSupportedTransdata::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const AnfNodePtr &get_item,
|
||||
const AnfNodePtr &final_node, size_t final_index,
|
||||
const session::KernelWithIndex &origin_pair) const {
|
||||
// record the ref_pair
|
||||
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
|
@ -95,9 +96,10 @@ void DealRefTransAndCast::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph
|
|||
|
||||
// if get_item is nullptr, the additional node will link to the cnode
|
||||
// else the additional node will link to the get_item node (the get_item node link to cnode)
|
||||
CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
size_t output_index, size_t input_index,
|
||||
const CNodePtr &get_item) const {
|
||||
CNodePtr DealRefAndSpiltUnSupportedTransdata::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph,
|
||||
const CNodePtr &cnode, size_t output_index,
|
||||
size_t input_index,
|
||||
const CNodePtr &get_item) const {
|
||||
CNodePtr final_node = (get_item == nullptr ? cnode : get_item);
|
||||
bool need_refresh_ref_addr = false;
|
||||
size_t final_index = output_index;
|
||||
|
@ -149,8 +151,9 @@ CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_
|
|||
return final_node;
|
||||
}
|
||||
|
||||
CNodePtr DealRefTransAndCast::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node,
|
||||
const CNodePtr &cnode, const FuncGraphPtr &func_graph) const {
|
||||
CNodePtr DealRefAndSpiltUnSupportedTransdata::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node,
|
||||
const CNodePtr &cnode,
|
||||
const FuncGraphPtr &func_graph) const {
|
||||
std::vector<AnfNodePtr> depend_nodes;
|
||||
if (get_item != nullptr) {
|
||||
depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), get_item, final_node};
|
||||
|
@ -159,8 +162,8 @@ CNodePtr DealRefTransAndCast::MakeDependency(const CNodePtr &get_item, const CNo
|
|||
}
|
||||
return func_graph->NewCNode(depend_nodes);
|
||||
}
|
||||
CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::shared_ptr<kernel::OpInfo> &op_info) const {
|
||||
CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefForMultipleOutput(
|
||||
const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::shared_ptr<kernel::OpInfo> &op_info) const {
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
auto ref_infos = op_info->ref_infos();
|
||||
std::vector<AnfNodePtr> make_tuple_inputs;
|
||||
|
@ -185,8 +188,8 @@ CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_
|
|||
return make_tuple;
|
||||
}
|
||||
|
||||
CNodePtr DealRefTransAndCast::DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::shared_ptr<kernel::OpInfo> &op_info) const {
|
||||
CNodePtr DealRefAndSpiltUnSupportedTransdata::DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||
const std::shared_ptr<kernel::OpInfo> &op_info) const {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
auto ref_infos = op_info->ref_infos();
|
||||
|
@ -200,13 +203,14 @@ CNodePtr DealRefTransAndCast::DealRefSigleOutput(const FuncGraphPtr &func_graph,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
const BaseRef DealRefTransAndCast::DefinePattern() const {
|
||||
const BaseRef DealRefAndSpiltUnSupportedTransdata::DefinePattern() const {
|
||||
VarPtr V = std::make_shared<CondVar>(UnVisited);
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
return VectorRef({V, Xs});
|
||||
}
|
||||
|
||||
void DealRefTransAndCast::DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const {
|
||||
void DealRefAndSpiltUnSupportedTransdata::DealBroadCastAsRef(const FuncGraphPtr &func_graph,
|
||||
const CNodePtr &cnode) const {
|
||||
if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) {
|
||||
auto input_size = AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t i = 0; i < input_size; ++i) {
|
||||
|
@ -219,8 +223,8 @@ void DealRefTransAndCast::DealBroadCastAsRef(const FuncGraphPtr &func_graph, con
|
|||
}
|
||||
}
|
||||
|
||||
const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
const AnfNodePtr DealRefAndSpiltUnSupportedTransdata::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -250,11 +254,12 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph,
|
||||
const CNodePtr &cnode) const {
|
||||
CNodePtr DealRefAndSpiltUnSupportedTransdata::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph,
|
||||
const CNodePtr &cnode) const {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(cnode);
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
// When the input and output format is only one special format just need to be splited into transpose and transdata
|
||||
if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() ||
|
||||
kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) {
|
||||
if (IsFormatInvaild(cnode)) {
|
||||
|
@ -262,6 +267,8 @@ CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &f
|
|||
}
|
||||
return cnode;
|
||||
}
|
||||
// When input and output format are all special format
|
||||
// the node should be splited to two transdata connected by default format
|
||||
auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
|
||||
auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
|
||||
builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
|
@ -273,6 +280,8 @@ CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &f
|
|||
next_trans_node->set_abstract(cnode->abstract());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), cnode.get());
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get());
|
||||
RefreshKernelBuildInfo(AnfAlgo::GetInputFormat(cnode, 0), kOpFormat_DEFAULT, cnode);
|
||||
RefreshKernelBuildInfo(kOpFormat_DEFAULT, AnfAlgo::GetOutputFormat(next_trans_node, 0), next_trans_node);
|
||||
if (IsFormatInvaild(cnode)) {
|
||||
auto after_split_node = DoSplit(func_graph, cnode);
|
||||
AnfAlgo::SetNodeInput(next_trans_node, after_split_node, 0);
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_AND_SPLIT_UNSUPPORTED_TRANSADATA_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_AND_SPLIT_UNSUPPORTED_TRANSADATA_H_
|
||||
#include <memory>
|
||||
#include "ir/anf.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
@ -25,10 +25,11 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class DealRefTransAndCast : public TransDataSplit {
|
||||
class DealRefAndSpiltUnSupportedTransdata : public TransDataSplit {
|
||||
public:
|
||||
explicit DealRefTransAndCast(bool multigraph = true) : TransDataSplit(multigraph, "deal_ref_trans_and_cast") {}
|
||||
~DealRefTransAndCast() override = default;
|
||||
explicit DealRefAndSpiltUnSupportedTransdata(bool multigraph = true)
|
||||
: TransDataSplit(multigraph, "deal_ref_and_transdata_spilt") {}
|
||||
~DealRefAndSpiltUnSupportedTransdata() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
|
@ -52,4 +53,4 @@ class DealRefTransAndCast : public TransDataSplit {
|
|||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_AND_SPLIT_UNSUPPORTED_TRANSADATA_H_
|
Loading…
Reference in New Issue