forked from mindspore-Ecosystem/mindspore
Add attr kAttrSkipNopOpAddr for nop nodes hidden from the execution order
This commit is contained in:
parent
a7834c382f
commit
7f251e3f08
|
@ -302,20 +302,20 @@ bool IsAllNopNode(const session::KernelGraph *const graph) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool CheckNopNodeIsOutputNode(const std::vector<AnfNodePtr> &outputs, const AnfNodePtr &node, bool is_dynamic_graph) {
|
||||
bool NeedHideNode(const std::vector<AnfNodePtr> &outputs, const AnfNodePtr &node, bool is_dynamic_graph) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// if node is not a nop node, keep it in execution order
|
||||
if (!IsNopNode(node)) {
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
// if node is nop node and the graph is dynamic graph, check if the nop node is graph's output.
|
||||
if (is_dynamic_graph) {
|
||||
auto iter = find(outputs.begin(), outputs.end(), node);
|
||||
if (iter != outputs.end()) {
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
void HideNopNode(session::KernelGraph *const graph) {
|
||||
|
@ -330,7 +330,9 @@ void HideNopNode(session::KernelGraph *const graph) {
|
|||
std::vector<CNodePtr> new_nodes;
|
||||
for (auto &cnode : execution_order) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (CheckNopNodeIsOutputNode(outputs, cnode, is_dynamic_graph)) {
|
||||
if (NeedHideNode(outputs, cnode, is_dynamic_graph)) {
|
||||
AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(true), cnode);
|
||||
} else {
|
||||
new_nodes.push_back(cnode);
|
||||
}
|
||||
}
|
||||
|
@ -352,7 +354,8 @@ void RemoveNopNode(session::KernelGraph *const graph) {
|
|||
for (auto &cnode : graph->execution_order()) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// ignore nop node itself
|
||||
if (!CheckNopNodeIsOutputNode(outputs, cnode, is_dynamic_graph)) {
|
||||
if (NeedHideNode(outputs, cnode, is_dynamic_graph)) {
|
||||
AnfAlgo::SetNodeAttr(kAttrSkipNopOpAddr, MakeValue(true), cnode);
|
||||
continue;
|
||||
}
|
||||
// Replace the input which is nop node
|
||||
|
|
|
@ -501,18 +501,17 @@ uint8_t *MemReuseUtil::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index)
|
|||
return ptr;
|
||||
}
|
||||
|
||||
session::KernelWithIndex MemReuseUtil::VisitKernelWithReturnType(const AnfNodePtr &node, size_t i,
|
||||
bool visit_nop_node) {
|
||||
session::KernelWithIndex MemReuseUtil::VisitKernelWithReturnType(const AnfNodePtr &node, size_t i, bool skip_nop_node) {
|
||||
if (!enable_visit_kernel_cache_ || i != 0) {
|
||||
return AnfAlgo::VisitKernelWithReturnType(node, i, visit_nop_node);
|
||||
return AnfAlgo::VisitKernelWithReturnType(node, i, skip_nop_node);
|
||||
}
|
||||
|
||||
auto &cache =
|
||||
visit_nop_node ? visit_kernel_with_return_type_in0pos_cache_ : visit_kernel_with_return_type_in0pos_skip_nop_cache_;
|
||||
skip_nop_node ? visit_kernel_with_return_type_in0pos_cache_ : visit_kernel_with_return_type_in0pos_skip_nop_cache_;
|
||||
std::unordered_map<AnfNodePtr, session::KernelWithIndex>::iterator tag_iter;
|
||||
if (auto iter = cache.find(node); iter == cache.end()) {
|
||||
auto tmp_item = std::pair<AnfNodePtr, session::KernelWithIndex>{
|
||||
node, AnfAlgo::VisitKernelWithReturnType(node, i, visit_nop_node)};
|
||||
auto tmp_item =
|
||||
std::pair<AnfNodePtr, session::KernelWithIndex>{node, AnfAlgo::VisitKernelWithReturnType(node, i, skip_nop_node)};
|
||||
tag_iter = cache.emplace(tmp_item).first;
|
||||
} else {
|
||||
tag_iter = iter;
|
||||
|
|
|
@ -87,7 +87,7 @@ class MemReuseUtil {
|
|||
uint8_t *GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const;
|
||||
uint8_t *GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const;
|
||||
bool is_all_nop_node() const { return is_all_nop_node_; }
|
||||
session::KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &node, size_t i, bool visit_nop_node);
|
||||
session::KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &node, size_t i, bool skip_nop_node);
|
||||
|
||||
private:
|
||||
KernelRefs kernel_output_refs_;
|
||||
|
|
|
@ -309,7 +309,7 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr
|
|||
if (AnfAlgo::IsOneOfPrimitiveCNode(cnode, follow_first_input_prims)) {
|
||||
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), index, skip_nop_node, return_types);
|
||||
}
|
||||
if (opt::IsNopNode(cnode) && (skip_nop_node || IsNeedSkipNopOp(cnode))) {
|
||||
if (opt::IsNopNode(cnode) && skip_nop_node) {
|
||||
if (cnode->size() != kNopNodeInputSize) {
|
||||
MS_LOG(EXCEPTION) << "Invalid nop node " << cnode->DebugString() << " trace: " << trace::DumpSourceLines(cnode);
|
||||
}
|
||||
|
@ -787,18 +787,18 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
|
|||
}
|
||||
|
||||
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx,
|
||||
bool visit_nop_node) {
|
||||
bool skip_nop_node) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
if (!anf_node->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."
|
||||
<< " trace: " << trace::DumpSourceLines(anf_node);
|
||||
}
|
||||
if (CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
|
||||
return VisitKernelWithReturnType(anf_node, 0, visit_nop_node);
|
||||
return VisitKernelWithReturnType(anf_node, 0, skip_nop_node);
|
||||
}
|
||||
auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
return VisitKernelWithReturnType(input_node, 0, visit_nop_node);
|
||||
return VisitKernelWithReturnType(input_node, 0, skip_nop_node);
|
||||
}
|
||||
|
||||
std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
|
||||
|
@ -1034,10 +1034,9 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &an
|
|||
}
|
||||
|
||||
// get output device addr of anf_node
|
||||
const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx,
|
||||
bool visit_nop_node) {
|
||||
const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool skip_nop_node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (opt::IsNopNode(node) && (visit_nop_node || IsNeedSkipNopOp(node))) {
|
||||
if (opt::IsNopNode(node) && (skip_nop_node || IsNeedSkipNopOpAddr(node))) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->size() == kNopNodeInputSize) {
|
||||
|
@ -1059,9 +1058,9 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node,
|
|||
}
|
||||
|
||||
DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx,
|
||||
bool visit_nop_node) {
|
||||
bool skip_nop_node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (opt::IsNopNode(node) && (visit_nop_node || IsNeedSkipNopOp(node))) {
|
||||
if (opt::IsNopNode(node) && (skip_nop_node || IsNeedSkipNopOpAddr(node))) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().size() == kNopNodeInputSize) {
|
||||
|
@ -1083,14 +1082,14 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
|
|||
}
|
||||
|
||||
// check whether the output device addr of anf_node exists
|
||||
bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node) {
|
||||
bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx, bool skip_nop_node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (opt::IsNopNode(node) && (visit_nop_node || IsNeedSkipNopOp(node))) {
|
||||
if (opt::IsNopNode(node) && (skip_nop_node || IsNeedSkipNopOpAddr(node))) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().size() > 1) {
|
||||
auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(cnode, 0);
|
||||
return OutputAddrExist(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
|
||||
return OutputAddrExist(kernel_with_index.first, kernel_with_index.second, skip_nop_node);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -1109,15 +1108,15 @@ bool AnfRuntimeAlgorithm::WorkspaceAddrExist(const AnfNodePtr &node, size_t outp
|
|||
}
|
||||
|
||||
const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
|
||||
bool visit_nop_node) {
|
||||
bool skip_nop_node) {
|
||||
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
|
||||
return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
|
||||
return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, skip_nop_node);
|
||||
}
|
||||
|
||||
DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
|
||||
bool visit_nop_node) {
|
||||
bool skip_nop_node) {
|
||||
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
|
||||
return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
|
||||
return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, skip_nop_node);
|
||||
}
|
||||
|
||||
// set output device addr of anf_node
|
||||
|
@ -1687,7 +1686,7 @@ bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
|
|||
return kernel_name == kGetNextOpName;
|
||||
}
|
||||
|
||||
bool AnfRuntimeAlgorithm::IsNeedSkipNopOp(const AnfNodePtr &node) {
|
||||
bool AnfRuntimeAlgorithm::IsNeedSkipNopOpAddr(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
|
@ -1698,12 +1697,12 @@ bool AnfRuntimeAlgorithm::IsNeedSkipNopOp(const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
auto skip_nop_op_attr = primitive->GetAttr(kAttrSkipNopOp);
|
||||
if (skip_nop_op_attr == nullptr) {
|
||||
auto skip_nop_op_addr_attr = primitive->GetAttr(kAttrSkipNopOpAddr);
|
||||
if (skip_nop_op_addr_attr == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return GetValue<bool>(skip_nop_op_attr);
|
||||
return GetValue<bool>(skip_nop_op_addr_attr);
|
||||
}
|
||||
|
||||
FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) {
|
||||
|
@ -2410,7 +2409,7 @@ void AnfRuntimeAlgorithm::CacheAddrForKernel(const AnfNodePtr &node, kernel::Ker
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto visit_nop_node = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode);
|
||||
auto skip_nop_node = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode);
|
||||
size_t input_num = GetInputTensorNum(node);
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto op_name = GetCNodeName(cnode);
|
||||
|
@ -2426,7 +2425,7 @@ void AnfRuntimeAlgorithm::CacheAddrForKernel(const AnfNodePtr &node, kernel::Ker
|
|||
}
|
||||
}
|
||||
auto real_input = GetRealInputIndex(node, i);
|
||||
auto device_address = GetPrevNodeOutputAddr(node, real_input, visit_nop_node);
|
||||
auto device_address = GetPrevNodeOutputAddr(node, real_input, skip_nop_node);
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
kernel::AddressPtr input = std::make_shared<kernel::Address>();
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
|
@ -2436,7 +2435,7 @@ void AnfRuntimeAlgorithm::CacheAddrForKernel(const AnfNodePtr &node, kernel::Ker
|
|||
kernel_inputs.emplace_back(input);
|
||||
}
|
||||
for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
|
||||
auto device_address = GetOutputAddr(node, i, visit_nop_node);
|
||||
auto device_address = GetOutputAddr(node, i, skip_nop_node);
|
||||
kernel::AddressPtr output = std::make_shared<kernel::Address>();
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
output->addr = const_cast<void *>(device_address->GetPtr());
|
||||
|
|
|
@ -150,7 +150,7 @@ class AnfRuntimeAlgorithm {
|
|||
// get input format select of anf node
|
||||
static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx);
|
||||
// get prev node output with output index
|
||||
static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool visit_nop_node = false);
|
||||
static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool skip_nop_node = false);
|
||||
// get output format from prev node, input_index is the input index of current node related to prev node
|
||||
static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
|
||||
// get reshape_type from the output of the input node.
|
||||
|
@ -184,18 +184,18 @@ class AnfRuntimeAlgorithm {
|
|||
// get output select data type from prev node, input_index is the input index of current node related to prev node
|
||||
static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx);
|
||||
// get output device addr of anf_node
|
||||
static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true);
|
||||
static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool skip_nop_node = true);
|
||||
// get mutable output device addr of anf_node
|
||||
static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true);
|
||||
static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool skip_nop_node = true);
|
||||
// check whether output addr exists or not
|
||||
static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = false);
|
||||
static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx, bool skip_nop_node = false);
|
||||
// check whether workspace addr exists or not
|
||||
static bool WorkspaceAddrExist(const AnfNodePtr &node, size_t output_idx);
|
||||
// get address from prev node, input_index is the input index of current node related to prev node
|
||||
static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx,
|
||||
bool visit_nop_node = true);
|
||||
bool skip_nop_node = true);
|
||||
static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
|
||||
bool visit_nop_node = true);
|
||||
bool skip_nop_node = true);
|
||||
// set output device addr of anf_node
|
||||
static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
|
||||
// set workspace device addr of anf_node
|
||||
|
@ -271,7 +271,7 @@ class AnfRuntimeAlgorithm {
|
|||
static bool IsFusedCommunicationOp(const AnfNodePtr &node);
|
||||
static bool IsInplaceNode(const AnfNodePtr &node, const string &type);
|
||||
static bool IsGetNext(const NotNull<AnfNodePtr> &node);
|
||||
static bool IsNeedSkipNopOp(const AnfNodePtr &node);
|
||||
static bool IsNeedSkipNopOpAddr(const AnfNodePtr &node);
|
||||
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
|
||||
static std::vector<KernelGraphPtr> GetCallSwitchKernelGraph(const CNodePtr &cnode);
|
||||
static bool IsSwitchCall(const CNodePtr &call_node);
|
||||
|
|
|
@ -62,8 +62,8 @@ void DynamicKernel::RebuildDependTensor() {
|
|||
MS_EXCEPTION_IF_NULL(context);
|
||||
for (auto depend : depend_list_) {
|
||||
auto pre_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, depend);
|
||||
bool visit_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
|
||||
auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, depend, visit_nop_node);
|
||||
bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
|
||||
auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, depend, skip_nop_node);
|
||||
std::vector<int64_t> shapes = trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second);
|
||||
auto host_type = AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second);
|
||||
auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
|
||||
|
|
|
@ -1339,12 +1339,12 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel)
|
|||
}
|
||||
}
|
||||
|
||||
DeviceAddressPtr GPUKernelRuntime::GetPrevNodeMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node) {
|
||||
DeviceAddressPtr GPUKernelRuntime::GetPrevNodeMutableOutputAddr(const AnfNodePtr &node, size_t i, bool skip_nop_node) {
|
||||
if (!enable_relation_cache_) {
|
||||
return AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, visit_nop_node);
|
||||
return AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, skip_nop_node);
|
||||
}
|
||||
|
||||
auto &addr_cache = visit_nop_node ? prev_node_mut_output_addr_cache_ : prev_node_mut_output_addr_skip_nop_node_cache_;
|
||||
auto &addr_cache = skip_nop_node ? prev_node_mut_output_addr_cache_ : prev_node_mut_output_addr_skip_nop_node_cache_;
|
||||
std::unordered_map<AnfNodePtr, std::vector<session::KernelWithIndex>>::iterator addr_iter;
|
||||
if (auto iter = addr_cache.find(node); iter == addr_cache.end()) {
|
||||
addr_iter = addr_cache.insert({node, {AnfAlgo::GetInputTensorNum(node), {nullptr, 0}}}).first;
|
||||
|
@ -1353,7 +1353,7 @@ DeviceAddressPtr GPUKernelRuntime::GetPrevNodeMutableOutputAddr(const AnfNodePtr
|
|||
}
|
||||
|
||||
if (addr_iter->second[i].first == nullptr) {
|
||||
addr_iter->second[i] = AnfAlgo::GetPrevNodeOutput(node, i, visit_nop_node);
|
||||
addr_iter->second[i] = AnfAlgo::GetPrevNodeOutput(node, i, skip_nop_node);
|
||||
}
|
||||
|
||||
session::KernelWithIndex prev_node_with_index = addr_iter->second[i];
|
||||
|
@ -1364,12 +1364,12 @@ DeviceAddressPtr GPUKernelRuntime::GetPrevNodeMutableOutputAddr(const AnfNodePtr
|
|||
return addr;
|
||||
}
|
||||
|
||||
DeviceAddressPtr GPUKernelRuntime::GetMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node) {
|
||||
DeviceAddressPtr GPUKernelRuntime::GetMutableOutputAddr(const AnfNodePtr &node, size_t i, bool skip_nop_node) {
|
||||
if (!enable_relation_cache_) {
|
||||
return AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node);
|
||||
return AnfAlgo::GetMutableOutputAddr(node, i, skip_nop_node);
|
||||
}
|
||||
|
||||
auto &addr_cache = visit_nop_node ? mut_output_addr_cache_ : mut_output_addr_skip_nop_node_cache_;
|
||||
auto &addr_cache = skip_nop_node ? mut_output_addr_cache_ : mut_output_addr_skip_nop_node_cache_;
|
||||
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>>::iterator addr_iter;
|
||||
if (auto iter = addr_cache.find(node); iter == addr_cache.end()) {
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(node);
|
||||
|
@ -1382,12 +1382,12 @@ DeviceAddressPtr GPUKernelRuntime::GetMutableOutputAddr(const AnfNodePtr &node,
|
|||
|
||||
auto &now_addr = addr_iter->second[i];
|
||||
if (now_addr == nullptr) {
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node);
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, skip_nop_node);
|
||||
now_addr = device_address;
|
||||
} else {
|
||||
if (addr_state_.count(now_addr) > 0) {
|
||||
addr_state_.erase(now_addr);
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node);
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, skip_nop_node);
|
||||
now_addr = device_address;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -110,8 +110,8 @@ class GPUKernelRuntime : public KernelRuntime {
|
|||
bool IsDistributedTraining(const session::KernelGraph *graph);
|
||||
void FetchMemUnitSize(const session::KernelGraph *graph);
|
||||
|
||||
DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node);
|
||||
DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node);
|
||||
DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &node, size_t i, bool skip_nop_node);
|
||||
DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t i, bool skip_nop_node);
|
||||
session::KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &node, size_t i);
|
||||
|
||||
void LaunchKernelWithoutMock(const session::KernelGraph *graph, const AnfNodePtr &kernel,
|
||||
|
|
|
@ -1127,7 +1127,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
|
|||
}
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto visit_nop_node = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode);
|
||||
auto skip_nop_node = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto op_name = AnfAlgo::GetCNodeName(cnode);
|
||||
|
@ -1143,7 +1143,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
|
|||
}
|
||||
}
|
||||
auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
|
||||
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input, visit_nop_node);
|
||||
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input, skip_nop_node);
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
kernel::AddressPtr input = std::make_shared<kernel::Address>();
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
|
@ -1154,7 +1154,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
|
|||
}
|
||||
|
||||
for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
|
||||
auto device_address = AnfAlgo::GetOutputAddr(kernel, i, visit_nop_node);
|
||||
auto device_address = AnfAlgo::GetOutputAddr(kernel, i, skip_nop_node);
|
||||
kernel::AddressPtr output = std::make_shared<kernel::Address>();
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
output->addr = device_address->ptr_;
|
||||
|
|
|
@ -479,7 +479,7 @@ constexpr auto kAttrProfilingIterEnd = "PROFILING_ITER_END";
|
|||
constexpr auto kAttrHiddenSize = "hidden_size";
|
||||
constexpr auto kAttrInputSize = "input_size";
|
||||
constexpr auto kAttrDstType = "dst_type";
|
||||
constexpr auto kAttrSkipNopOp = "skip_nop_op";
|
||||
constexpr auto kAttrSkipNopOpAddr = "skip_nop_op_addr";
|
||||
|
||||
// primal attr key name
|
||||
constexpr auto kPrimalAttrForwardNodeName = "forward_node_name";
|
||||
|
|
Loading…
Reference in New Issue