From 04136bdf66ec328feaf7ef2f9f18f75c74ff2275 Mon Sep 17 00:00:00 2001
From: liubuyu
Date: Thu, 26 Nov 2020 21:18:12 +0800
Subject: [PATCH] nop node infer shape

---
 .../kernel_compiler/rts/memcpy_async.cc       |  3 +-
 .../ccsrc/backend/session/session_basic.cc    |  3 +-
 mindspore/ccsrc/common/trans.cc               | 56 +++++++------------
 mindspore/ccsrc/common/trans.h                |  2 -
 .../device/ascend/ascend_device_address.cc    | 15 ++---
 .../device/ascend/ascend_kernel_runtime.cc    | 40 +++++++++++++
 .../runtime/device/executor/dynamic_kernel.h  |  1 +
 7 files changed, 72 insertions(+), 48 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc
index 699a1d61a89..f05c302f93c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc
@@ -17,6 +17,7 @@
 #include "backend/kernel_compiler/rts/memcpy_async.h"
 #include
 #include
+#include "abstract/utils.h"
 #include "runtime/mem.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "common/trans.h"
@@ -89,7 +90,7 @@ void MemCpyAsyncKernel::GetInputOutputTotalCount(const AnfNodePtr &anf_node) {
   if (input_size != 1) {
     MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1";
   }
-  size_t type_size = trans::TypeIdSize(input_type_id_);
+  size_t type_size = abstract::TypeIdSize(input_type_id_);
   std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, 0);
   size_t total_size = 1;
   for (size_t i = 0; i < shape_i.size(); i++) {
diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc
index bac09b1434e..8986c84599d 100644
--- a/mindspore/ccsrc/backend/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/session/session_basic.cc
@@ -20,6 +20,7 @@
 #include "c_ops/primitive_c.h"
 #include "ir/manager.h"
 
+#include "abstract/utils.h"
 #include "backend/kernel_compiler/common_utils.h"
 #include "base/core_ops.h"
 #include "common/trans.h"
@@ -1093,7 +1094,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
       (void)std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape_tmp), IntToSize);
       AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape_tmp},
                                           input_node.get());
-      size = trans::ShapeSize(shape_tmp) * trans::TypeIdSize(tensor->data_type());
+      size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
     }
     if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
       auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
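Note: the trans.cc diff below deletes the file-local TypeIdSize/ShapeSize helpers and their type_map
table; every call site (memcpy_async.cc and session_basic.cc above, plus the files that follow)
switches to the shared implementations in abstract/utils.h. A minimal sketch of what the two helpers
compute, mirroring the removed trans.cc code — the TypeId enum and the map entries here are
abbreviated stand-ins for illustration, not the real MindSpore definitions:

    #include <cstddef>
    #include <functional>
    #include <map>
    #include <numeric>
    #include <vector>

    enum TypeId { kNumberTypeBool, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32 };

    // Byte width of one element; 0 signals an unsupported dtype (callers check "< 1").
    size_t TypeIdSize(TypeId t) {
      static const std::map<TypeId, size_t> type_map = {
          {kNumberTypeBool, 1}, {kNumberTypeInt32, 4}, {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}};
      auto iter = type_map.find(t);
      return iter != type_map.end() ? iter->second : 0;
    }

    // Element count of a shape: the product of all dimensions.
    size_t ShapeSize(const std::vector<size_t> &shape) {
      return std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
    }

    // Tensor byte size, as computed at the call sites in this patch:
    //   size = ShapeSize(shape) * TypeIdSize(dtype);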
diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc
index be2772ae1a9..b8729190e6c 100644
--- a/mindspore/ccsrc/common/trans.cc
+++ b/mindspore/ccsrc/common/trans.cc
@@ -18,6 +18,7 @@
 #include
 #include
 #include "utils/ms_utils.h"
+#include "abstract/utils.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/kernel_compiler/kernel.h"
 #include "runtime/device/convert_tensor_utils.h"
@@ -28,12 +29,6 @@
 namespace mindspore {
 namespace trans {
 enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNdhwc };
-const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1},    {kNumberTypeInt, 4},    {kNumberTypeInt8, 1},
-                                           {kNumberTypeInt16, 2},   {kNumberTypeInt32, 4},  {kNumberTypeInt64, 8},
-                                           {kNumberTypeUInt, 4},    {kNumberTypeUInt8, 1},  {kNumberTypeUInt16, 2},
-                                           {kNumberTypeUInt32, 4},  {kNumberTypeUInt64, 8}, {kNumberTypeFloat, 4},
-                                           {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}};
-
 inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
   switch (size) {
     case 1:
@@ -117,8 +112,8 @@ const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
   {std::pair(kNumberTypeBool, kNumberTypeFloat16), FROM_BOOL_TO_FLOAT16}};
 
 void CheckMemSize(const TypeIdArgs &args) {
-  auto src_type_size = TypeIdSize(args.host_data_type);
-  auto dst_type_size = TypeIdSize(args.device_data_type);
+  auto src_type_size = abstract::TypeIdSize(args.host_data_type);
+  auto dst_type_size = abstract::TypeIdSize(args.device_data_type);
   if (src_type_size < 1 || dst_type_size < 1) {
     MS_LOG(EXCEPTION) << "Invalid src or dst data type.";
   }
@@ -192,7 +187,7 @@ bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const
 
 size_t CubeSizeByType(const TypeId data_type) {
   const size_t default_error = 0;
-  auto dt_size = TypeIdSize(data_type);
+  auto dt_size = abstract::TypeIdSize(data_type);
   if (dt_size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return default_error;
@@ -202,19 +197,6 @@ size_t CubeSizeByType(const TypeId data_type) {
   return kCubeSize;
 }
 
-size_t ShapeSize(const std::vector<size_t> &shape) {
-  return std::accumulate(shape.begin(), shape.end(), IntToSize(1), std::multiplies<size_t>());
-}
-
-size_t TypeIdSize(const TypeId data_type) {
-  const size_t unsupported_type_error = 0;
-  auto iter = type_map.find(data_type);
-  if (iter != type_map.end()) {
-    return iter->second;
-  }
-  return unsupported_type_error;
-}
-
 namespace {
 bool CheckDims(const std::vector<size_t> &shape) {
   if (shape.size() != kNchwDims) {
@@ -477,12 +459,12 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
   }
   MS_EXCEPTION_IF_NULL(size);
   MS_EXCEPTION_IF_NULL(total_size);
-  *size = TypeIdSize(args.src_data_type);
+  *size = abstract::TypeIdSize(args.src_data_type);
   if (*size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  *total_size = ShapeSize(args.device_shape) * (*size);
+  *total_size = abstract::ShapeSize(args.device_shape) * (*size);
   if (*total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << *total_size << ", device_size:" << args.device_size;
     return false;
@@ -516,7 +498,7 @@ bool TransFormat(const FormatArgs &args, void *result) {
     {kOpFormat_NC1HWC0, NchwToNc1hwc0},        {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
     {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}};
   MS_LOG(DEBUG) << "Start trans format.";
-  if (TypeIdSize(args.src_data_type) < 1) {
+  if (abstract::TypeIdSize(args.src_data_type) < 1) {
     MS_LOG(ERROR) << "Invalid datatype..";
     return false;
   }
@@ -538,7 +520,7 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
     {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
     {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}};
   MS_LOG(DEBUG) << "Start trans format.";
-  if (TypeIdSize(args.src_data_type) < 1) {
+  if (abstract::TypeIdSize(args.src_data_type) < 1) {
     MS_LOG(ERROR) << "Invalid datatype..";
     return false;
   }
@@ -624,7 +606,7 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
@@ -685,12 +667,12 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  auto total_size = ShapeSize(args.device_shape) * size;
+  auto total_size = abstract::ShapeSize(args.device_shape) * size;
   if (total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
     return false;
@@ -828,13 +810,13 @@ bool NchwToFracNz(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid shape size.";
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype";
     return false;
   }
 
-  auto dst_size = ShapeSize(args.device_shape) * size;
+  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
   if (dst_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
     return false;
@@ -890,13 +872,13 @@ bool FracNzToNchw(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid shape size.";
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype";
     return false;
   }
 
-  auto dst_size = ShapeSize(args.device_shape) * size;
+  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
   if (dst_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
     return false;
@@ -947,12 +929,12 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  auto total_size = ShapeSize(args.device_shape) * size;
+  auto total_size = abstract::ShapeSize(args.device_shape) * size;
   if (total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
     return false;
@@ -1005,12 +987,12 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  auto total_size = ShapeSize(args.device_shape) * size;
+  auto total_size = abstract::ShapeSize(args.device_shape) * size;
   if (total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
     return false;
diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h
index 702dda04507..cad059eaa26 100644
--- a/mindspore/ccsrc/common/trans.h
+++ b/mindspore/ccsrc/common/trans.h
@@ -48,8 +48,6 @@ struct FormatArgs {
   TypeId src_data_type;
 };
 
-size_t TypeIdSize(const TypeId data_type);
-size_t ShapeSize(const std::vector<size_t> &shape);
 size_t CubeSizeByType(const TypeId data_type);
 std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape,
                                      const std::vector<size_t> &padding_axis = {});
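Note: with these declarations gone from common/trans.h, any remaining caller has to change both the
include and the namespace; the substitution is mechanical (a sketch, assuming a caller that only
needs the byte-size computation):

    // before this patch
    #include "common/trans.h"
    size_t bytes = trans::ShapeSize(shape) * trans::TypeIdSize(dtype);

    // after this patch
    #include "abstract/utils.h"
    size_t bytes = abstract::ShapeSize(shape) * abstract::TypeIdSize(dtype);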
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
index 8a08c796c9f..ec1c1cacf64 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
@@ -26,6 +26,7 @@
 #include "runtime/device/convert_tensor_utils.h"
 #include "ir/dtype/type.h"
 #include "ir/tensor.h"
+#include "abstract/utils.h"
 #include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
 #include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h"
 #include "utils/utils.h"
@@ -298,7 +299,7 @@ bool AscendDeviceAddress::SyncDeviceToHost(const ShapeVector &shape, size_t size
   } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
     sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, ptr_, size_);
   } else {
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     auto host = std::vector<uint8_t>(size_);
     SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST);
     const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size_};
@@ -413,11 +414,11 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
   MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
   auto host_size = size;
   if (type_id_ != type) {
-    auto device_dtype_size = trans::TypeIdSize(type_id_);
+    auto device_dtype_size = abstract::TypeIdSize(type_id_);
     if (device_dtype_size < 1) {
       MS_LOG(ERROR) << "Illegal dtype.";
     }
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     size = device_dtype_size * shape_size;
   }
   size = GetCommonAlignSize(size);
@@ -431,7 +432,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
   } else {
     auto host = std::vector<uint8_t>(size);
     SyncMemory(host.data(), output_address->GetPtr(), size, RT_MEMCPY_DEVICE_TO_HOST);
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, host_size};
     sync_ok = trans::TransDataType(type_args, host_ptr);
     if (!sync_ok) {
@@ -500,7 +501,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh
     MS_LOG(ERROR) << "Trans format failed.";
     return false;
   }
-  auto shape_size = trans::ShapeSize(host_shape);
+  auto shape_size = abstract::ShapeSize(host_shape);
   const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size};
   sync_ok = trans::TransDataType(type_args, host_ptr);
   if (!sync_ok) {
@@ -537,7 +538,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size
   } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
     sync_ok = Float64ToFloatAndSyncHostToDevice(ptr_, size_, host_ptr, size);
   } else {
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size};
     auto host_tmp = std::vector<uint8_t>(size_);
     sync_ok = trans::TransDataType(type_args, host_tmp.data());
@@ -581,7 +582,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   }
   if (type_id_ != type) {
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size};
     auto host_tmp = std::vector<uint8_t>(size_);
     sync_ok = trans::TransDataType(type_args, host_tmp.data());
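Note: each sync path above has the same structure: compute the element count from the host shape,
stage the raw device bytes, then convert element-wise when the host and device dtypes differ. A
simplified, self-contained stand-in for the trans::TransDataType step (the real code dispatches on
a (host, device) TypeId pair through mode_map, shown in the trans.cc diff):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Element-wise cast from a staging buffer of SrcT into a destination buffer of DstT.
    template <typename SrcT, typename DstT>
    void CastBuffer(const void *src, void *dst, size_t elem_count) {
      auto in = static_cast<const SrcT *>(src);
      auto out = static_cast<DstT *>(dst);
      for (size_t i = 0; i < elem_count; ++i) {
        out[i] = static_cast<DstT>(in[i]);
      }
    }

    // Usage mirroring SyncDeviceToHost's fallback branch (device int32 -> host int64):
    //   auto host = std::vector<uint8_t>(device_size);   // staging buffer
    //   SyncMemory(host.data(), device_ptr, device_size, RT_MEMCPY_DEVICE_TO_HOST);
    //   CastBuffer<int32_t, int64_t>(host.data(), host_ptr, shape_size);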
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
index 79cdeb5c4b7..bea01fd8a17 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
@@ -22,6 +22,8 @@
 #include
 #include
 #include
+#include <stack>
+#include "abstract/primitive_infer_map.h"
 #include "debug/data_dump/e2e_dump_util.h"
 #include "runtime/device/ascend/ascend_device_address.h"
 #include "runtime/device/cpu/mpi/mpi_interface.h"
@@ -39,6 +41,7 @@
 #include "backend/session/anf_runtime_algorithm.h"
 #include "runtime/device/ascend/profiling/profiling_utils.h"
 #include "backend/kernel_compiler/tbe/tbe_utils.h"
+#include "backend/optimizer/common/helper.h"
 #include "runtime/device/ascend/ascend_memory_manager.h"
 #include "debug/tensor_load.h"
 #include "debug/data_dump/dump_json_parser.h"
@@ -110,6 +113,34 @@ std::string GetRankId() {
   }
   return rank_id_str;
 }
+
+void InferShapeForNopNode(AnfNodePtr *input_node) {
+  MS_EXCEPTION_IF_NULL(*input_node);
+  if (!opt::IsNopNode(*input_node)) {
+    MS_LOG(INFO) << "Input node is not a nop node, no need infer.";
+    return;
+  }
+  MS_LOG(INFO) << "Infer shape for nop node.";
+  std::stack<AnfNodePtr> nop_road;
+  nop_road.push(*input_node);
+
+  while (true) {
+    auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(*input_node, 0);
+    auto in_node = input_node_with_idx.first;
+    MS_EXCEPTION_IF_NULL(in_node);
+    if (opt::IsNopNode(in_node)) {
+      nop_road.push(in_node);
+      *input_node = in_node;
+    } else {
+      break;
+    }
+  }
+  while (!nop_road.empty()) {
+    auto nop_node = nop_road.top();
+    AnfAlgo::InferShape(nop_node->cast<CNodePtr>());
+    nop_road.pop();
+  }
+}
 }  // namespace
 
 std::vector AscendKernelRuntime::exception_infoes_;
@@ -633,6 +664,15 @@ bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph *grap
     }
 
     if (dynamic_kernel->is_dynamic_shape()) {
+      auto kernel_node = dynamic_kernel->kernel_node();
+      MS_EXCEPTION_IF_NULL(kernel_node);
+      auto input_size = AnfAlgo::GetInputTensorNum(kernel_node);
+      for (size_t i = 0; i < input_size; i++) {
+        auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(kernel_node, i);
+        auto input_node = input_node_with_index.first;
+        MS_EXCEPTION_IF_NULL(input_node);
+        InferShapeForNopNode(&input_node);
+      }
       dynamic_kernel->InferShape();
       dynamic_kernel->UpdateArgs();
     }
diff --git a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h
index 6f977751918..43a438b9feb 100644
--- a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h
+++ b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h
@@ -48,6 +48,7 @@ class DynamicKernel {
   virtual void Initialize();
   std::string GetKernelName() { return cnode_ptr_->fullname_with_scope(); }
   int GetKernelType();
+  CNodePtr kernel_node() const { return cnode_ptr_; }
 
 protected:
   void RebuildDependTensor();
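Note on the core change: a nop node (an op that moves no data and launches no kernel) never updates
its output shape at run time, so a dynamic-shape kernel fed by one would read a stale shape.
InferShapeForNopNode walks input-0 edges upward through a chain of nop nodes, stacking them so the
one nearest the real producer ends up on top, then pops and re-infers shapes top-down;
RunDynamicKernelAsync calls it for every input of a dynamic-shape kernel before InferShape/UpdateArgs.
A self-contained sketch of that traversal — Node, is_nop, producer, and infer_shape are toy stand-ins
for AnfNodePtr, opt::IsNopNode, AnfAlgo::GetPrevNodeOutput, and AnfAlgo::InferShape:

    #include <memory>
    #include <stack>
    #include <vector>

    struct Node {
      bool is_nop = false;
      std::shared_ptr<Node> producer;   // node producing this node's input 0
      std::vector<size_t> out_shape;
      // Toy re-infer: a nop node's output shape is its producer's output shape.
      void infer_shape() {
        if (producer != nullptr) out_shape = producer->out_shape;
      }
    };
    using NodePtr = std::shared_ptr<Node>;

    void InferShapeForNopChain(const NodePtr &node) {
      if (node == nullptr || !node->is_nop) return;   // real kernels need no help
      std::stack<NodePtr> nop_road;                   // deepest nop ends up on top
      for (NodePtr n = node; n != nullptr && n->is_nop; n = n->producer) {
        nop_road.push(n);
      }
      while (!nop_road.empty()) {                     // re-infer from the producer side down
        nop_road.top()->infer_shape();
        nop_road.pop();
      }
    }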