rename real input index to input index in graph

This commit is contained in:
LaiYongqiang 2022-07-08 14:59:49 +08:00
parent 4a81173b1b
commit d517d9f8d0
10 changed files with 34 additions and 32 deletions

View File

@ -54,14 +54,14 @@ constexpr size_t kReturnDataIndex = 1;
constexpr size_t kSwitchTrueBranchIndex = 2;
// ops pair that dynamic input order is differ from the fixed shape ops
// pair: <real_input->ori_input, ori_input->real_input>
// pair: <input_index_in_kernel->input_index_in_graph, input_index_in_graph->input_index_in_kernel>
static std::map<std::string, std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>>> spec_dynamic_node_list = {
{prim::kPrimStridedSliceGrad->name(),
{{{0, 1}, {1, 2}, {2, 3}, {3, 4}, {4, 0}}, {{1, 0}, {2, 1}, {3, 2}, {4, 3}, {0, 4}}}},
{prim::kPrimConv2DBackpropInput->name(), {{{0, 2}, {1, 1}, {2, 0}}, {{0, 2}, {1, 1}, {2, 0}}}},
{prim::kPrimConv2DBackpropFilter->name(), {{{0, 1}, {1, 2}, {2, 0}}, {{1, 0}, {2, 1}, {0, 2}}}}};
// pair: <real_input->ori_input, ori_input->real_input>
// pair: <input_index_in_kernel->input_index_in_graph, input_index_in_graph->input_index_in_kernel>
static std::map<std::string, std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>>> spec_node_list = {
{prim::kPrimConv2DBackpropInput->name(), {{{0, 1}, {1, 0}}, {{0, 1}, {1, 0}}}},
{kFusionOpConv2DBackpropInputReluGradV2Name, {{{0, 1}, {1, 0}, {2, 2}}, {{0, 1}, {1, 0}, {2, 2}}}},
@ -824,16 +824,17 @@ bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input
return IsFeatureMapOutput(input_node);
}
size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {
size_t AnfRuntimeAlgorithm::GetInputIndexInGraph(const mindspore::AnfNodePtr &anf_node,
const size_t input_index_in_kernel) {
MS_EXCEPTION_IF_NULL(anf_node);
size_t ret = cur_index;
size_t ret = input_index_in_kernel;
auto node_name = common::AnfAlgo::GetCNodeName(anf_node);
if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) {
if (common::AnfAlgo::IsDynamicShape(anf_node)) {
auto find_dynamic = spec_dynamic_node_list.find(node_name);
if (find_dynamic != spec_dynamic_node_list.cend()) {
auto dyn_index_converter = find_dynamic->second;
ret = dyn_index_converter.first[cur_index];
ret = dyn_index_converter.first[input_index_in_kernel];
MS_LOG(DEBUG) << "Real input index change to " << ret << ", node name:" << node_name;
return ret;
}
@ -841,7 +842,7 @@ size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_n
if (op_info != nullptr) {
auto real_input_index = op_info->real_input_index();
if (!real_input_index.first.empty()) {
ret = real_input_index.first[cur_index];
ret = real_input_index.first[input_index_in_kernel];
return ret;
}
}
@ -849,7 +850,7 @@ size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_n
auto find = spec_node_list.find(node_name);
if (find != spec_node_list.cend()) {
auto index_converter = find->second;
ret = index_converter.first[cur_index];
ret = index_converter.first[input_index_in_kernel];
MS_LOG(DEBUG) << "Real input index change to " << ret << ", node name:" << node_name;
return ret;
}
@ -857,7 +858,7 @@ size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_n
if (op_info != nullptr) {
auto real_input_index = op_info->real_input_index();
if (!real_input_index.first.empty()) {
ret = real_input_index.first[cur_index];
ret = real_input_index.first[input_index_in_kernel];
return ret;
}
}
@ -865,16 +866,17 @@ size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_n
return ret;
}
size_t AnfRuntimeAlgorithm::GetOriginalInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {
size_t AnfRuntimeAlgorithm::GetInputIndexInKernel(const mindspore::AnfNodePtr &anf_node,
const size_t input_index_in_graph) {
MS_EXCEPTION_IF_NULL(anf_node);
size_t ret = cur_index;
size_t ret = input_index_in_graph;
auto node_name = common::AnfAlgo::GetCNodeName(anf_node);
if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) {
if (common::AnfAlgo::IsDynamicShape(anf_node)) {
auto find_dynamic = spec_dynamic_node_list.find(node_name);
if (find_dynamic != spec_dynamic_node_list.cend()) {
auto dyn_index_converter = find_dynamic->second;
ret = dyn_index_converter.second[cur_index];
ret = dyn_index_converter.second[input_index_in_graph];
MS_LOG(DEBUG) << "Get original input index " << ret << ", node name:" << node_name;
return ret;
}
@ -882,7 +884,7 @@ size_t AnfRuntimeAlgorithm::GetOriginalInputIndex(const mindspore::AnfNodePtr &a
if (op_info != nullptr) {
auto real_input_index = op_info->real_input_index();
if (!real_input_index.second.empty()) {
ret = real_input_index.second[cur_index];
ret = real_input_index.second[input_index_in_graph];
return ret;
}
}
@ -890,14 +892,14 @@ size_t AnfRuntimeAlgorithm::GetOriginalInputIndex(const mindspore::AnfNodePtr &a
auto find = spec_node_list.find(node_name);
if (find != spec_node_list.cend()) {
auto index_converter = find->second;
ret = index_converter.second[cur_index];
ret = index_converter.second[input_index_in_graph];
MS_LOG(DEBUG) << "Get original input index " << ret << ", node name:" << node_name;
}
auto op_info = kernel::OpLib::FindOp(node_name, kernel::kTBE);
if (op_info != nullptr) {
auto real_input_index = op_info->real_input_index();
if (!real_input_index.second.empty()) {
ret = real_input_index.second[cur_index];
ret = real_input_index.second[input_index_in_graph];
return ret;
}
}
@ -1189,7 +1191,7 @@ void AnfRuntimeAlgorithm::CacheAddrForGraph(const KernelGraphPtr &kernel_graph)
// And hard code here should be removed after new Transdata programme is implemented in the foreseeable future.
if (common::AnfAlgo::HasNodeAttr(kAttrNopOp, kernel)) {
for (size_t idx = 0; idx < common::AnfAlgo::GetOutputTensorNum(kernel); idx += 1) {
auto real_input = GetRealInputIndex(kernel, idx);
auto real_input = GetInputIndexInGraph(kernel, idx);
auto device_address = GetPrevNodeMutableOutputAddr(kernel, real_input);
SetOutputAddr(device_address, idx, kernel.get());
}
@ -1221,7 +1223,7 @@ void AnfRuntimeAlgorithm::CacheAddrForKernel(const AnfNodePtr &node, kernel::Ker
if (common::AnfAlgo::IsNoneInput(node, i)) {
continue;
}
auto real_input = GetRealInputIndex(node, i);
auto real_input = GetInputIndexInGraph(node, i);
auto device_address = GetPrevNodeOutputAddr(node, real_input, skip_nop_node);
MS_EXCEPTION_IF_NULL(device_address);
kernel::AddressPtr input = std::make_shared<kernel::Address>();

View File

@ -153,10 +153,10 @@ class BACKEND_EXPORT AnfRuntimeAlgorithm {
static bool IsFeatureMapOutput(const AnfNodePtr &node);
// charge if the node's input is from a feature map output
static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index);
// get real input index for some tbe ops which input order is different between me and tbe impl
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
// get me input index for some tbe ops which input order is different between me and tbe impl
static size_t GetOriginalInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
// get input index in graph for some tbe ops which input order is different between graph and tbe kernel
static size_t GetInputIndexInGraph(const AnfNodePtr &anf_node, const size_t input_index_in_kernel);
// get input index in kernel for some tbe ops which input order is different between graph and tbe kernel
static size_t GetInputIndexInKernel(const AnfNodePtr &anf_node, const size_t input_index_in_graph);
static std::vector<KernelGraphPtr> GetCallSwitchKernelGraph(const CNodePtr &cnode);
static bool IsIndependentNode(const CNodePtr &node);
static void InferShape(const CNodePtr &node, std::map<uint32_t, tensor::TensorPtr> *depend_tensors = nullptr);

View File

@ -379,7 +379,7 @@ void MemReuseChecker::CheckNormalIR(const session::KernelGraph *graph) {
<< " is larger than input number: " << common::AnfAlgo::GetInputTensorNum(node)
<< trace::DumpSourceLines(node);
}
auto real_input_index = AnfAlgo::GetRealInputIndex(node, i);
auto real_input_index = AnfAlgo::GetInputIndexInGraph(node, i);
auto input = node->input(real_input_index + 1);
MS_EXCEPTION_IF_NULL(input);
auto kernel_with_index = common::AnfAlgo::VisitKernel(input, 0);

View File

@ -1057,7 +1057,7 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph &graph) {
size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
for (size_t i = 0; i < input_num; ++i) {
auto real_input_index = AnfAlgo::GetRealInputIndex(node, i);
auto real_input_index = AnfAlgo::GetInputIndexInGraph(node, i);
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(node, real_input_index);
MS_LOG(INFO) << "Input idx " << i << " size " << device_address->size_ << " addr " << device_address->ptr_;
int32_t value = 0;

View File

@ -473,7 +473,7 @@ void DataDumper::DumpKernelInput(const CNodePtr &kernel, void *args, NotNull<aic
auto input_size = common::AnfAlgo::GetInputTensorNum(kernel);
uint64_t offset = 0;
for (size_t i = 0; i < input_size; ++i) {
auto real_index = AnfAlgo::GetRealInputIndex(kernel, i);
auto real_index = AnfAlgo::GetInputIndexInGraph(kernel, i);
if (common::AnfAlgo::IsNoneInput(kernel, real_index)) {
continue;
}

View File

@ -191,7 +191,7 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i
if (common::AnfAlgo::IsNoneInput(anf_node_ptr, i)) {
continue;
}
auto input_index_in_graph = AnfAlgo::GetRealInputIndex(anf_node_ptr, i);
auto input_index_in_graph = AnfAlgo::GetInputIndexInGraph(anf_node_ptr, i);
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(anf_node_ptr, input_index_in_graph);
AddressPtr input = std::make_shared<Address>();
MS_EXCEPTION_IF_NULL(input);

View File

@ -364,7 +364,7 @@ bool AscendKernelExecutor::GetKernelRealInputs(const CNodePtr &kernel, const vec
}
for (size_t i = 0; i < input_num; ++i) {
auto real_index = AnfAlgo::GetRealInputIndex(kernel, i);
auto real_index = AnfAlgo::GetInputIndexInGraph(kernel, i);
if (real_index >= input_num) {
MS_LOG(ERROR) << "Total input num is " << input_num << " but get real_index " << real_index;
return false;

View File

@ -79,7 +79,7 @@ void AssignOutputNopNodeDeviceAddress(const KernelGraphPtr &graph, const device:
continue;
}
auto real_input_index = AnfAlgo::GetRealInputIndex(output, 0);
auto real_input_index = AnfAlgo::GetInputIndexInGraph(output, 0);
auto pre_node_out_device_address = AnfAlgo::GetPrevNodeOutputAddr(output, real_input_index);
MS_EXCEPTION_IF_NULL(pre_node_out_device_address);
auto ptr = pre_node_out_device_address->GetPtr();

View File

@ -107,7 +107,7 @@ void OpTilingCalculateAdapter::ConvertInputShapeAndType(const CNodePtr &node, ::
auto input_size = common::AnfAlgo::GetInputTensorNum(node);
for (size_t i = 0; i < input_size; i++) {
// ms info
auto real_index = AnfAlgo::GetRealInputIndex(node, i);
auto real_index = AnfAlgo::GetInputIndexInGraph(node, i);
auto input_node_with_idx = common::AnfAlgo::GetPrevNodeOutput(node, real_index);
auto input_node = input_node_with_idx.first;
auto input_index = input_node_with_idx.second;
@ -311,7 +311,7 @@ std::vector<std::tuple<std::size_t, ::ge::NodePtr>> OpTilingCalculateAdapter::Co
auto depend_name = input_names_attr[index];
auto const_tensor = iter->second;
::ge::NodePtr ge_constant_op = NewConstantOp(node, depend_name, const_tensor, ge_graph, index);
auto original_index = AnfAlgo::GetOriginalInputIndex(node, index);
auto original_index = AnfAlgo::GetInputIndexInKernel(node, index);
constant_ops.emplace_back(std::tuple<std::size_t, ::ge::NodePtr>(original_index, ge_constant_op));
op_infer_depends.emplace_back(depend_name);
}

View File

@ -326,7 +326,7 @@ void KernelRuntime::ResetNodeAddress(const session::KernelGraph &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_mod);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel);
for (size_t j = 0; j < input_num; ++j) {
auto input_index = AnfAlgo::GetRealInputIndex(kernel, j);
auto input_index = AnfAlgo::GetInputIndexInGraph(kernel, j);
KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(kernel, input_index, true);
auto index = kernel_with_index.second;
auto &input_node = kernel_with_index.first;
@ -1258,7 +1258,7 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
if (common::AnfAlgo::IsNoneInput(kernel, i)) {
continue;
}
auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
auto real_input = AnfAlgo::GetInputIndexInGraph(kernel, i);
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input, skip_nop_node);
MS_EXCEPTION_IF_NULL(device_address);
kernel::AddressPtr input = std::make_shared<kernel::Address>();
@ -1487,7 +1487,7 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel);
const auto update_parameter = common::AnfAlgo::IsUpdateParameterKernel(cnode);
for (size_t j = 0; j < input_num; ++j) {
auto real_input = AnfAlgo::GetRealInputIndex(kernel, j);
auto real_input = AnfAlgo::GetInputIndexInGraph(kernel, j);
auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(kernel, real_input, true);
auto index = kernel_with_index.second;
auto &input_node = kernel_with_index.first;
@ -1779,7 +1779,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
// And hard code here should be removed after new Transdata programme is implemented in the foreseeable future.
if (common::AnfAlgo::HasNodeAttr(kAttrNopOp, kernel)) {
for (size_t idx = 0; idx < common::AnfAlgo::GetOutputTensorNum(kernel); idx += 1) {
auto real_input = AnfAlgo::GetRealInputIndex(kernel, idx);
auto real_input = AnfAlgo::GetInputIndexInGraph(kernel, idx);
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input);
AnfAlgo::SetOutputAddr(device_address, idx, kernel.get());
}