forked from mindspore-Ecosystem/mindspore
!11888 Add non-task optimizer for Split and SplitV.
From: @liu_xiao_93 Reviewed-by: @liangchenghui,@wuxuejian Signed-off-by: @liangchenghui
This commit is contained in:
commit
1adfd93ff9
|
@ -118,6 +118,7 @@
|
||||||
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
|
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h"
|
||||||
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_gru.h"
|
#include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_gru.h"
|
||||||
#include "backend/optimizer/ascend/enhancer/add_attr_for_3d_graph.h"
|
#include "backend/optimizer/ascend/enhancer/add_attr_for_3d_graph.h"
|
||||||
|
#include "backend/optimizer/ascend/enhancer/split_n_optimizer.h"
|
||||||
#include "utils/ms_context.h"
|
#include "utils/ms_context.h"
|
||||||
#include "utils/config_manager.h"
|
#include "utils/config_manager.h"
|
||||||
#include "debug/anf_ir_dump.h"
|
#include "debug/anf_ir_dump.h"
|
||||||
|
@ -366,6 +367,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
|
||||||
other_pm->AddPass(std::make_shared<InsertMemcpyAsyncForCascade>());
|
other_pm->AddPass(std::make_shared<InsertMemcpyAsyncForCascade>());
|
||||||
other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
|
other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
|
||||||
other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
|
other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
|
||||||
|
other_pm->AddPass(std::make_shared<SplitOpOptimizer>());
|
||||||
optimizer->AddPassManager(other_pm);
|
optimizer->AddPassManager(other_pm);
|
||||||
(void)optimizer->Optimize(kernel_graph);
|
(void)optimizer->Optimize(kernel_graph);
|
||||||
kernel_graph->SetExecOrderByDefault();
|
kernel_graph->SetExecOrderByDefault();
|
||||||
|
|
|
@ -0,0 +1,211 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "backend/optimizer/ascend/enhancer/split_n_optimizer.h"
|
||||||
|
#include <set>
|
||||||
|
#include <utility>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "utils/utils.h"
|
||||||
|
#include "base/core_ops.h"
|
||||||
|
#include "backend/optimizer/common/helper.h"
|
||||||
|
|
||||||
|
namespace mindspore::opt {
namespace {
// Pair of (producer node, output index) used while tracing through
// MakeTuple / TupleGetItem wrappers.
using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
// Ops that must not directly consume a non-task Split's output: they would
// alias the same memory region and break the offset-based reuse scheme.
const std::set<std::string> InvalidOps = {kSplitOpName, kSplitVOpName, kConcatOpName};
||||||
|
void GetSplitOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &node, std::vector<AnfNodePtr> *out_nodes) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
auto manager = func_graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
|
if (manager->node_users().find(node) == manager->node_users().end()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (const auto &node_index : manager->node_users()[node]) {
|
||||||
|
const AnfNodePtr &output = node_index.first;
|
||||||
|
MS_EXCEPTION_IF_NULL(output);
|
||||||
|
if (!output->isa<CNode>()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto cnode = output->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
auto input0 = cnode->input(0);
|
||||||
|
MS_EXCEPTION_IF_NULL(input0);
|
||||||
|
if (IsPrimitive(input0, prim::kPrimMakeTuple) || IsPrimitive(input0, prim::kPrimTupleGetItem)) {
|
||||||
|
GetSplitOutputs(func_graph, output, out_nodes);
|
||||||
|
} else {
|
||||||
|
out_nodes->push_back(output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve (anf_node, index) to the real producer kernel, stepping through
// MakeTuple / TupleGetItem wrappers (mirrors AnfAlgo::VisitKernel).
KernelWithIndex VisitSplitKernel(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  // Leaves: parameters and value nodes terminate the search at output 0.
  if (anf_node->isa<ValueNode>() || anf_node->isa<Parameter>()) {
    return std::make_pair(anf_node, 0);
  }
  if (!anf_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "The input is invalid";
  }
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto input0 = cnode->input(0);
  MS_EXCEPTION_IF_NULL(input0);
  if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
    // Step into the index-th tuple element (input 0 is the primitive).
    auto element = cnode->input(index + IntToSize(1));
    MS_EXCEPTION_IF_NULL(element);
    return VisitSplitKernel(element, 0);
  }
  if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
    if (cnode->inputs().size() != kTupleGetItemInputSize) {
      MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
    }
    auto index_input = cnode->input(kInputNodeOutputIndexInTupleGetItem);
    MS_EXCEPTION_IF_NULL(index_input);
    auto index_value_node = index_input->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(index_value_node);
    auto item_idx = GetValue<int64_t>(index_value_node->value());
    // Follow the real producer with the extracted output index.
    return VisitSplitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), LongToSize(item_idx));
  }
  // An ordinary CNode: this is the producer we were looking for.
  return std::make_pair(anf_node, index);
}
||||||
|
bool InputCheck(const AnfNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
auto in_nums = AnfAlgo::GetInputTensorNum(node);
|
||||||
|
for (size_t i = 0; i < in_nums; i++) {
|
||||||
|
auto in_node = VisitSplitKernel(AnfAlgo::GetInputNode(cnode, i), 0).first;
|
||||||
|
if (in_node->isa<Parameter>() || in_node->isa<ValueNode>()) {
|
||||||
|
MS_LOG(INFO) << "Input is a Parameter or ValueNode, can not optimizer.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (in_node->isa<CNode>()) {
|
||||||
|
auto in_node_name = AnfAlgo::GetCNodeName(in_node->cast<CNodePtr>());
|
||||||
|
auto trans_input = AnfAlgo::VisitKernel(in_node, 0).first;
|
||||||
|
if (in_node_name == kTransDataOpName && (trans_input->isa<Parameter>() || trans_input->isa<ValueNode>())) {
|
||||||
|
MS_LOG(INFO) << "Data->TransData->split, can not optimizer.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (in_node_name == prim::kPrimControlDepend->name() || in_node_name == prim::kPrimDepend->name()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if ((AnfAlgo::HasNodeAttr("non_task", cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, "non_task")) ||
|
||||||
|
(AnfAlgo::HasNodeAttr("nop_node", cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, "nop_node"))) {
|
||||||
|
MS_LOG(INFO) << "Input has non_task or nop_node attr, can not optimizer.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool OutputCheck(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
std::vector<AnfNodePtr> outputs;
|
||||||
|
GetSplitOutputs(func_graph, node, &outputs);
|
||||||
|
if (outputs.empty()) {
|
||||||
|
MS_LOG(INFO) << "Next node has no outputs, can not optimizer.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (const auto &item : outputs) {
|
||||||
|
if (IsPrimitiveCNode(item, prim::kPrimControlDepend) || IsPrimitiveCNode(item, prim::kPrimDepend)) {
|
||||||
|
MS_LOG(INFO) << "Split has control edge, can not optimizer.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (AnfAlgo::IsRealKernel(item) && (AnfAlgo::GetProcessor(item) != 0)) {
|
||||||
|
MS_LOG(INFO) << "Next node is not a AICore node, can not optimizer.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (func_graph->output() == item || AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
|
||||||
|
MS_LOG(INFO) << "Next node is graph output or return, can not optimizer.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto op_name = AnfAlgo ::GetCNodeName(item);
|
||||||
|
if (InvalidOps.find(op_name) != InvalidOps.end() || AnfAlgo::IsCommunicationOp(node)) {
|
||||||
|
MS_LOG(INFO) << "Next node is " << item->fullname_with_scope() << ", not a invalid node, can not optimizer.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!AnfAlgo::GetOutputTensorNum(item)) {
|
||||||
|
MS_LOG(INFO) << "Next node has no output, can not optimizer.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool NeedSkip(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
FuncGraphManagerPtr func_manager = func_graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(func_manager);
|
||||||
|
int64_t split_dim = -1;
|
||||||
|
auto op_name = AnfAlgo::GetCNodeName(node);
|
||||||
|
if (op_name == prim::kPrimSplit->name()) {
|
||||||
|
split_dim = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrAxis);
|
||||||
|
} else if (op_name == prim::kPrimSplitV->name()) {
|
||||||
|
split_dim = AnfAlgo::GetNodeAttr<int64_t>(node, kAttrSplitDim);
|
||||||
|
}
|
||||||
|
if (split_dim != 0) {
|
||||||
|
MS_LOG(INFO) << "Split_dim is not 0, can not optimizer.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (AnfAlgo::IsDynamicShape(node)) {
|
||||||
|
MS_LOG(INFO) << "Split is dynamic shape, can not optimizer.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (!(InputCheck(node))) {
|
||||||
|
MS_LOG(INFO) << "Split input check failed, can not optimizer.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (!(OutputCheck(func_graph, node))) {
|
||||||
|
MS_LOG(INFO) << "Split output check failed, can not optimizer.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Match any not-yet-visited node with arbitrary inputs; the actual filtering
// to Split/SplitV happens inside Process().
const BaseRef SplitOpOptimizer::DefinePattern() const {
  auto unvisited_node = std::make_shared<CondVar>(UnVisited);
  auto any_inputs = std::make_shared<SeqVar>();
  return VectorRef({unvisited_node, any_inputs});
}
||||||
|
// Tag an optimizable Split/SplitV with the "non_task" attribute so that later
// phases (SOMAS, task generation) implement it by memory offsets instead of a
// real device task. Returns the node when tagged, nullptr otherwise.
const AnfNodePtr SplitOpOptimizer::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                           const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(func_graph);
  if (!AnfAlgo::IsRealCNodeKernel(node)) {
    return nullptr;
  }
  // Mark visited first so the pattern never matches this node again.
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
  auto kernel_name = AnfAlgo::GetCNodeName(node);
  const bool is_split = kernel_name == prim::kPrimSplit->name() || kernel_name == prim::kPrimSplitV->name();
  if (!is_split || NeedSkip(func_graph, node)) {
    return nullptr;
  }
  AnfAlgo::SetNodeAttr("non_task", MakeValue(true), node);
  return node;
}
||||||
|
} // namespace mindspore::opt
|
|
@ -0,0 +1,35 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_SPLIT_N_OPTIMIZER_H
|
||||||
|
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_SPLIT_N_OPTIMIZER_H
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
#include "backend/optimizer/common/optimizer.h"
|
||||||
|
#include "backend/optimizer/ascend/ascend_helper.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
// Pattern pass that marks eligible Split/SplitV ops as "non_task" so they are
// realized via memory offsets instead of a real device task.
class SplitOpOptimizer : public PatternProcessPass {
 public:
  // multigraph=true applies the pass across all sub-graphs.
  explicit SplitOpOptimizer(bool multigraph = true) : PatternProcessPass("split_op_optimizer", multigraph) {}
  ~SplitOpOptimizer() override = default;
  // Matches any unvisited node; the Split/SplitV filter lives in Process().
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_SPLIT_N_OPTIMIZER_H
|
|
@ -86,6 +86,7 @@ bool Somas::InitSomasTensors(const session::KernelGraph *graph) {
|
||||||
IndependentNodeOutputProcess(graph);
|
IndependentNodeOutputProcess(graph);
|
||||||
SummaryInputProcess(graph);
|
SummaryInputProcess(graph);
|
||||||
RefNodeProcess(graph);
|
RefNodeProcess(graph);
|
||||||
|
NonTaskSplitProcess(graph);
|
||||||
UnReuseNodeProcess(graph);
|
UnReuseNodeProcess(graph);
|
||||||
GenContiguousList(graph);
|
GenContiguousList(graph);
|
||||||
GetNextOutputProcess(graph);
|
GetNextOutputProcess(graph);
|
||||||
|
@ -535,6 +536,30 @@ void Somas::RefNodeProcess(const session::KernelGraph *graph) {
|
||||||
MS_LOG(INFO) << "Special Tensor total size: RefNode: input " << total_input_size << " output " << total_output_size;
|
MS_LOG(INFO) << "Special Tensor total size: RefNode: input " << total_input_size << " output " << total_output_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For every Split/SplitV tagged with kAttrNonTask, group its first input
// tensor and all of its output tensors into one ref-node constraint so SOMAS
// assigns them overlapping (offset-based) memory instead of separate buffers.
void Somas::NonTaskSplitProcess(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto kernel_cnodes = graph->execution_order();
  for (const auto &kernel : kernel_cnodes) {
    auto op_name = AnfAlgo::GetCNodeName(kernel);
    if ((op_name != kSplitOpName && op_name != kSplitVOpName) || !AnfAlgo::HasNodeAttr(kAttrNonTask, kernel)) {
      continue;
    }
    // Look up with find(): operator[] would default-insert a null entry for
    // an unknown kernel and the null would be dereferenced below.
    auto iter = nodes_map_.find(kernel.get());
    if (iter == nodes_map_.end()) {
      MS_LOG(EXCEPTION) << op_name << " has no somas node, can not do split non_task process.";
    }
    auto node = iter->second;
    MS_EXCEPTION_IF_NULL(node);
    if (node->input_tensors_.size() == 0) {
      MS_LOG(EXCEPTION) << op_name << " has no input tensor, can not do split non_task process.";
    }
    std::vector<size_t> refnode_input_output;
    auto input_tensor = node->input_tensors_[0];
    input_tensor->type_ = kRefNodeInput;
    refnode_input_output.push_back(input_tensor->GetId());
    for (auto &output_tensor : node->output_tensors_) {
      output_tensor->type_ = kRefNodeOutput;
      refnode_input_output.push_back(output_tensor->GetId());
    }
    ref_node_constraints_.push_back(refnode_input_output);
  }
}
|
||||||
|
|
||||||
void Somas::UnReuseNodeProcess(const session::KernelGraph *graph) {
|
void Somas::UnReuseNodeProcess(const session::KernelGraph *graph) {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
vector<string> full_name_list = {};
|
vector<string> full_name_list = {};
|
||||||
|
|
|
@ -114,6 +114,7 @@ class Somas {
|
||||||
void IndependentNodeOutputProcess(const session::KernelGraph *graph);
|
void IndependentNodeOutputProcess(const session::KernelGraph *graph);
|
||||||
void SummaryInputProcess(const session::KernelGraph *graph);
|
void SummaryInputProcess(const session::KernelGraph *graph);
|
||||||
void RefNodeProcess(const session::KernelGraph *graph);
|
void RefNodeProcess(const session::KernelGraph *graph);
|
||||||
|
void NonTaskSplitProcess(const session::KernelGraph *graph);
|
||||||
void UnReuseNodeProcess(const session::KernelGraph *graph);
|
void UnReuseNodeProcess(const session::KernelGraph *graph);
|
||||||
SomasTensorPtr CreateGapTensor(size_t gap_tensor_id);
|
SomasTensorPtr CreateGapTensor(size_t gap_tensor_id);
|
||||||
void GenContiguousList(const session::KernelGraph *graph);
|
void GenContiguousList(const session::KernelGraph *graph);
|
||||||
|
|
|
@ -136,7 +136,12 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i
|
||||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||||
kernel_mod->set_kernel_name(anf_node_ptr->fullname_with_scope());
|
kernel_mod->set_kernel_name(anf_node_ptr->fullname_with_scope());
|
||||||
auto op_name = AnfAlgo::GetCNodeName(anf_node_ptr);
|
auto op_name = AnfAlgo::GetCNodeName(anf_node_ptr);
|
||||||
if (AnfAlgo::GetCNodeName(anf_node_ptr) != kAtomicAddrCleanOpName) {
|
if ((op_name == kSplitOpName || op_name == kSplitVOpName) && AnfAlgo::HasNodeAttr(kAttrNonTask, anf_node_ptr)) {
|
||||||
|
MS_LOG(INFO) << "Skip task generation for NonTask op " << anf_node_ptr->fullname_with_scope();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op_name != kAtomicAddrCleanOpName) {
|
||||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node_ptr); ++i) {
|
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node_ptr); ++i) {
|
||||||
if (op_name == kDynamicRNNOpName && i == 3) {
|
if (op_name == kDynamicRNNOpName && i == 3) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -153,6 +158,23 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i
|
||||||
AddressPtr input = std::make_shared<Address>();
|
AddressPtr input = std::make_shared<Address>();
|
||||||
input->addr = device_address->ptr_;
|
input->addr = device_address->ptr_;
|
||||||
input->size = device_address->size_;
|
input->size = device_address->size_;
|
||||||
|
|
||||||
|
auto prenode_with_index = AnfAlgo::GetPrevNodeOutput(anf_node_ptr, i);
|
||||||
|
if (AnfAlgo::IsRealCNodeKernel(prenode_with_index.first)) {
|
||||||
|
if ((AnfAlgo::GetCNodeName(prenode_with_index.first) == kSplitOpName ||
|
||||||
|
AnfAlgo::GetCNodeName(prenode_with_index.first) == kSplitVOpName) &&
|
||||||
|
AnfAlgo::HasNodeAttr(kAttrNonTask, prenode_with_index.first->cast<CNodePtr>())) {
|
||||||
|
// use memory offset to implement NonTask Type Split op
|
||||||
|
// when op A -> split(NonTask) -> op B, op B's input addr is split's input0's addr + offset
|
||||||
|
// offset is split's output index * split's output size
|
||||||
|
auto split_input0_device_address = AnfAlgo::GetPrevNodeOutputAddr(prenode_with_index.first, 0);
|
||||||
|
input->addr =
|
||||||
|
static_cast<uint8_t *>(split_input0_device_address->ptr_) + (prenode_with_index.second * input->size);
|
||||||
|
MS_LOG(INFO) << "Change " << anf_node_ptr->fullname_with_scope() << "'s input " << i << " address to "
|
||||||
|
<< split_input0_device_address->ptr_ << " + "
|
||||||
|
<< "prenode_with_index.second * input->size";
|
||||||
|
}
|
||||||
|
}
|
||||||
kernel_inputs.push_back(input);
|
kernel_inputs.push_back(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -379,6 +379,7 @@ constexpr auto kAttrDilation = "dilation";
|
||||||
constexpr auto kAttrPadMode = "pad_mode";
|
constexpr auto kAttrPadMode = "pad_mode";
|
||||||
constexpr auto kAttrPad = "pad";
|
constexpr auto kAttrPad = "pad";
|
||||||
constexpr auto kAttrPadding = "padding";
|
constexpr auto kAttrPadding = "padding";
|
||||||
|
constexpr auto kAttrNonTask = "non_task";
|
||||||
constexpr auto kAttrIsGrad = "is_grad";
|
constexpr auto kAttrIsGrad = "is_grad";
|
||||||
constexpr auto kAttrRecompute = "recompute";
|
constexpr auto kAttrRecompute = "recompute";
|
||||||
constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
|
constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
|
||||||
|
|
Loading…
Reference in New Issue