forked from mindspore-Ecosystem/mindspore
add io format attr for 3d graph
This commit is contained in:
parent
fd36e0911e
commit
f5af5da364
|
@ -64,6 +64,7 @@
|
|||
#include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h"
|
||||
#include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h"
|
||||
#include "backend/optimizer/ascend/format_type/insert_trans_op.h"
|
||||
#include "backend/optimizer/ascend/format_type/add_attr_for_3d_graph.h"
|
||||
#include "backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.h"
|
||||
#include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h"
|
||||
#include "backend/optimizer/ascend/format_type/insert_transpose_for_dyanmic_gru_v2.h"
|
||||
|
@ -224,6 +225,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
|
|||
auto data_layout_pm = std::make_shared<PassManager>("transop_pm");
|
||||
data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>());
|
||||
data_layout_pm->AddPass(std::make_shared<DynamicRNNGradReformat>());
|
||||
data_layout_pm->AddPass(std::make_shared<AddIoFormatAttrFor3DGraph>());
|
||||
data_layout_pm->AddPass(std::make_shared<InsertTransOp>());
|
||||
data_layout_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/ascend/format_type/add_attr_for_3d_graph.h"
|
||||
#include <memory>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "utils/utils.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
void AddAttrForAllCNode(const std::vector<AnfNodePtr> &node_list) {
|
||||
for (auto node : node_list) {
|
||||
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
|
||||
continue;
|
||||
}
|
||||
AnfAlgo::SetNodeAttr("io_format", MakeValue(kOpFormat_NCDHW), node);
|
||||
}
|
||||
}
|
||||
|
||||
bool NodeHasAttrIoFormat(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::HasNodeAttr("io_format", cnode)) {
|
||||
auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, "io_format");
|
||||
return attr == kOpFormat_NCDHW;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool AddIoFormatAttrFor3DGraph::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
|
||||
bool changed = false;
|
||||
if (std::any_of(node_list.begin(), node_list.end(),
|
||||
[](const AnfNodePtr &node) { return NodeHasAttrIoFormat(node); })) {
|
||||
AddAttrForAllCNode(node_list);
|
||||
changed = true;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_ADD_ATTR_FOR_3D_GRAPH_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_ADD_ATTR_FOR_3D_GRAPH_H
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include "ir/anf.h"
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class AddIoFormatAttrFor3DGraph : public Pass {
|
||||
public:
|
||||
explicit AddIoFormatAttrFor3DGraph(size_t groups = 1) : Pass("add_attr_for_3d_graph"), groups_(groups) {}
|
||||
~AddIoFormatAttrFor3DGraph() override = default;
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
size_t groups_ = 1;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_ADD_ATTR_FOR_3D_GRAPH_H
|
Loading…
Reference in New Issue