nop node infer shape

Author: liubuyu
Date:   2020-11-26 21:18:12 +08:00
parent dabb82ec7a
commit 04136bdf66
7 changed files with 72 additions and 48 deletions
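
Summary (from the diff below): the byte-size helpers TypeIdSize and ShapeSize move out of namespace trans — their definitions in the trans implementation file and declarations in its header are deleted, and every call site now uses the abstract:: versions from "abstract/utils.h". Separately, AscendKernelRuntime::RunDynamicKernelAsync gains an InferShapeForNopNode pass that re-infers output shapes along chains of nop nodes feeding a dynamic-shape kernel, supported by a new kernel_node() accessor on DynamicKernel and by opt::IsNopNode.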

View File

@@ -17,6 +17,7 @@
 #include "backend/kernel_compiler/rts/memcpy_async.h"
 #include <memory>
 #include <string>
+#include "abstract/utils.h"
 #include "runtime/mem.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "common/trans.h"
@@ -89,7 +90,7 @@ void MemCpyAsyncKernel::GetInputOutputTotalCount(const AnfNodePtr &anf_node) {
   if (input_size != 1) {
     MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1";
   }
-  size_t type_size = trans::TypeIdSize(input_type_id_);
+  size_t type_size = abstract::TypeIdSize(input_type_id_);
   std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, 0);
   size_t total_size = 1;
   for (size_t i = 0; i < shape_i.size(); i++) {

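For reference, GetInputOutputTotalCount above computes the copy size as element width times element count. A minimal sketch of that arithmetic (TotalBytes is a hypothetical name; the loop truncated above presumably multiplies total_size by each dimension before scaling by type_size):

#include <cstddef>
#include <vector>

// bytes = element width * product of all dimensions
std::size_t TotalBytes(std::size_t type_size, const std::vector<std::size_t> &shape) {
  std::size_t total = type_size;
  for (std::size_t dim : shape) {
    total *= dim;  // e.g. shape {2, 3, 4} contributes 24 elements
  }
  return total;
}
// TotalBytes(2, {2, 3, 4}) == 48 bytes for a float16 tensor of shape {2, 3, 4}.
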
View File

@@ -20,6 +20,7 @@
 #include "c_ops/primitive_c.h"
 #include "ir/manager.h"
+#include "abstract/utils.h"
 #include "backend/kernel_compiler/common_utils.h"
 #include "base/core_ops.h"
 #include "common/trans.h"
@@ -1093,7 +1094,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
       (void)std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape_tmp), IntToSize);
       AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape_tmp},
                                           input_node.get());
-      size = trans::ShapeSize(shape_tmp) * trans::TypeIdSize(tensor->data_type());
+      size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
     }
     if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
       auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);

View File

@@ -18,6 +18,7 @@
 #include <numeric>
 #include <utility>
 #include "utils/ms_utils.h"
+#include "abstract/utils.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/kernel_compiler/kernel.h"
 #include "runtime/device/convert_tensor_utils.h"
@@ -28,12 +29,6 @@
 namespace mindspore {
 namespace trans {
 enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNdhwc };
-const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1},    {kNumberTypeInt, 4},     {kNumberTypeInt8, 1},
-                                           {kNumberTypeInt16, 2},   {kNumberTypeInt32, 4},   {kNumberTypeInt64, 8},
-                                           {kNumberTypeUInt, 4},    {kNumberTypeUInt8, 1},   {kNumberTypeUInt16, 2},
-                                           {kNumberTypeUInt32, 4},  {kNumberTypeUInt64, 8},  {kNumberTypeFloat, 4},
-                                           {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}};
 inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
   switch (size) {
     case 1:
@@ -117,8 +112,8 @@ const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
   {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat16), FROM_BOOL_TO_FLOAT16}};
 void CheckMemSize(const TypeIdArgs &args) {
-  auto src_type_size = TypeIdSize(args.host_data_type);
-  auto dst_type_size = TypeIdSize(args.device_data_type);
+  auto src_type_size = abstract::TypeIdSize(args.host_data_type);
+  auto dst_type_size = abstract::TypeIdSize(args.device_data_type);
   if (src_type_size < 1 || dst_type_size < 1) {
     MS_LOG(EXCEPTION) << "Invalid src or dst data type.";
   }
@@ -192,7 +187,7 @@ bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const
 size_t CubeSizeByType(const TypeId data_type) {
   const size_t default_error = 0;
-  auto dt_size = TypeIdSize(data_type);
+  auto dt_size = abstract::TypeIdSize(data_type);
   if (dt_size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return default_error;
@@ -202,19 +197,6 @@ size_t CubeSizeByType(const TypeId data_type) {
   return kCubeSize;
 }
-
-size_t ShapeSize(const std::vector<size_t> &shape) {
-  return std::accumulate(shape.begin(), shape.end(), IntToSize(1), std::multiplies<size_t>());
-}
-
-size_t TypeIdSize(const TypeId data_type) {
-  const size_t unsupported_type_error = 0;
-  auto iter = type_map.find(data_type);
-  if (iter != type_map.end()) {
-    return iter->second;
-  }
-  return unsupported_type_error;
-}
 namespace {
 bool CheckDims(const std::vector<size_t> &shape) {
   if (shape.size() != kNchwDims) {
@@ -477,12 +459,12 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
   }
   MS_EXCEPTION_IF_NULL(size);
   MS_EXCEPTION_IF_NULL(total_size);
-  *size = TypeIdSize(args.src_data_type);
+  *size = abstract::TypeIdSize(args.src_data_type);
   if (*size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  *total_size = ShapeSize(args.device_shape) * (*size);
+  *total_size = abstract::ShapeSize(args.device_shape) * (*size);
   if (*total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << *total_size << ", device_size:" << args.device_size;
     return false;
@@ -516,7 +498,7 @@ bool TransFormat(const FormatArgs &args, void *result) {
     {kOpFormat_NC1HWC0, NchwToNc1hwc0},        {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
     {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}};
   MS_LOG(DEBUG) << "Start trans format.";
-  if (TypeIdSize(args.src_data_type) < 1) {
+  if (abstract::TypeIdSize(args.src_data_type) < 1) {
     MS_LOG(ERROR) << "Invalid datatype..";
     return false;
   }
@@ -538,7 +520,7 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
     {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
     {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}};
   MS_LOG(DEBUG) << "Start trans format.";
-  if (TypeIdSize(args.src_data_type) < 1) {
+  if (abstract::TypeIdSize(args.src_data_type) < 1) {
     MS_LOG(ERROR) << "Invalid datatype..";
     return false;
   }
@@ -624,7 +606,7 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
@@ -685,12 +667,12 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  auto total_size = ShapeSize(args.device_shape) * size;
+  auto total_size = abstract::ShapeSize(args.device_shape) * size;
   if (total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
     return false;
@@ -828,13 +810,13 @@ bool NchwToFracNz(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid shape size.";
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype";
     return false;
   }
-  auto dst_size = ShapeSize(args.device_shape) * size;
+  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
   if (dst_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
     return false;
@@ -890,13 +872,13 @@ bool FracNzToNchw(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid shape size.";
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype";
     return false;
   }
-  auto dst_size = ShapeSize(args.device_shape) * size;
+  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
   if (dst_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
     return false;
@@ -947,12 +929,12 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  auto total_size = ShapeSize(args.device_shape) * size;
+  auto total_size = abstract::ShapeSize(args.device_shape) * size;
   if (total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
     return false;
@@ -1005,12 +987,12 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  auto total_size = ShapeSize(args.device_shape) * size;
+  auto total_size = abstract::ShapeSize(args.device_shape) * size;
   if (total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
     return false;

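The two helpers deleted above now live in the abstract namespace. A minimal sketch of what abstract::ShapeSize and abstract::TypeIdSize are expected to do, reconstructed from the removed trans:: definitions (the real implementations in abstract/utils.h may differ; TypeId and the kNumberType* constants are MindSpore's ir/dtype types and are assumed here):

#include <functional>
#include <map>
#include <numeric>
#include <vector>

namespace mindspore {
namespace abstract {
// Element count of a shape; an empty shape yields 1 (a scalar).
size_t ShapeSize(const std::vector<size_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
}

// Byte width of a TypeId, or 0 for unsupported types — which is why every
// caller above treats a result < 1 as an error.
size_t TypeIdSize(const TypeId data_type) {
  static const std::map<TypeId, size_t> type_map = {
      {kNumberTypeBool, 1},    {kNumberTypeInt, 4},     {kNumberTypeInt8, 1},
      {kNumberTypeInt16, 2},   {kNumberTypeInt32, 4},   {kNumberTypeInt64, 8},
      {kNumberTypeUInt, 4},    {kNumberTypeUInt8, 1},   {kNumberTypeUInt16, 2},
      {kNumberTypeUInt32, 4},  {kNumberTypeUInt64, 8},  {kNumberTypeFloat, 4},
      {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}};
  auto iter = type_map.find(data_type);
  return iter != type_map.end() ? iter->second : 0;
}
}  // namespace abstract
}  // namespace mindspore
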
View File

@@ -48,8 +48,6 @@ struct FormatArgs {
   TypeId src_data_type;
 };
-size_t TypeIdSize(const TypeId data_type);
-size_t ShapeSize(const std::vector<size_t> &shape);
 size_t CubeSizeByType(const TypeId data_type);
 std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis = {});

View File

@@ -26,6 +26,7 @@
 #include "runtime/device/convert_tensor_utils.h"
 #include "ir/dtype/type.h"
 #include "ir/tensor.h"
+#include "abstract/utils.h"
 #include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
 #include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h"
 #include "utils/utils.h"
@@ -298,7 +299,7 @@ bool AscendDeviceAddress::SyncDeviceToHost(const ShapeVector &shape, size_t size
   } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
     sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, ptr_, size_);
   } else {
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     auto host = std::vector<uint8_t>(size_);
     SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST);
     const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size_};
@@ -413,11 +414,11 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
   MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
   auto host_size = size;
   if (type_id_ != type) {
-    auto device_dtype_size = trans::TypeIdSize(type_id_);
+    auto device_dtype_size = abstract::TypeIdSize(type_id_);
     if (device_dtype_size < 1) {
       MS_LOG(ERROR) << "Illegal dtype.";
     }
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     size = device_dtype_size * shape_size;
   }
   size = GetCommonAlignSize(size);
@@ -431,7 +432,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
   } else {
     auto host = std::vector<uint8_t>(size);
     SyncMemory(host.data(), output_address->GetPtr(), size, RT_MEMCPY_DEVICE_TO_HOST);
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, host_size};
     sync_ok = trans::TransDataType(type_args, host_ptr);
     if (!sync_ok) {
@@ -500,7 +501,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh
     MS_LOG(ERROR) << "Trans format failed.";
     return false;
   }
-  auto shape_size = trans::ShapeSize(host_shape);
+  auto shape_size = abstract::ShapeSize(host_shape);
   const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size};
   sync_ok = trans::TransDataType(type_args, host_ptr);
   if (!sync_ok) {
@@ -537,7 +538,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size
   } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
     sync_ok = Float64ToFloatAndSyncHostToDevice(ptr_, size_, host_ptr, size);
   } else {
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size};
     auto host_tmp = std::vector<uint8_t>(size_);
     sync_ok = trans::TransDataType(type_args, host_tmp.data());
@@ -581,7 +582,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   }
   if (type_id_ != type) {
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size};
     auto host_tmp = std::vector<uint8_t>(size_);
     sync_ok = trans::TransDataType(type_args, host_tmp.data());

View File

@@ -22,6 +22,8 @@
 #include <exception>
 #include <algorithm>
 #include <thread>
+#include <stack>
+#include "abstract/primitive_infer_map.h"
 #include "debug/data_dump/e2e_dump_util.h"
 #include "runtime/device/ascend/ascend_device_address.h"
 #include "runtime/device/cpu/mpi/mpi_interface.h"
@@ -39,6 +41,7 @@
 #include "backend/session/anf_runtime_algorithm.h"
 #include "runtime/device/ascend/profiling/profiling_utils.h"
 #include "backend/kernel_compiler/tbe/tbe_utils.h"
+#include "backend/optimizer/common/helper.h"
 #include "runtime/device/ascend/ascend_memory_manager.h"
 #include "debug/tensor_load.h"
 #include "debug/data_dump/dump_json_parser.h"
@@ -110,6 +113,34 @@ std::string GetRankId() {
   }
   return rank_id_str;
 }
+
+void InferShapeForNopNode(AnfNodePtr *input_node) {
+  MS_EXCEPTION_IF_NULL(*input_node);
+  if (!opt::IsNopNode(*input_node)) {
+    MS_LOG(INFO) << "Input node is not a nop node, no need infer.";
+    return;
+  }
+  MS_LOG(INFO) << "Infer shape for nop node.";
+  std::stack<AnfNodePtr> nop_road;
+  nop_road.push(*input_node);
+
+  while (true) {
+    auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(*input_node, 0);
+    auto in_node = input_node_with_idx.first;
+    MS_EXCEPTION_IF_NULL(in_node);
+    if (opt::IsNopNode(in_node)) {
+      nop_road.push(in_node);
+      *input_node = in_node;
+    } else {
+      break;
+    }
+  }
+
+  while (!nop_road.empty()) {
+    auto nop_node = nop_road.top();
+    AnfAlgo::InferShape(nop_node->cast<CNodePtr>());
+    nop_road.pop();
+  }
+}
 } // namespace
 std::vector<rtExceptionInfo> AscendKernelRuntime::exception_infoes_;
@@ -633,6 +664,15 @@ bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph *grap
     }
     if (dynamic_kernel->is_dynamic_shape()) {
+      auto kernel_node = dynamic_kernel->kernel_node();
+      MS_EXCEPTION_IF_NULL(kernel_node);
+      auto input_size = AnfAlgo::GetInputTensorNum(kernel_node);
+      for (size_t i = 0; i < input_size; i++) {
+        auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(kernel_node, i);
+        auto input_node = input_node_with_index.first;
+        MS_EXCEPTION_IF_NULL(input_node);
+        InferShapeForNopNode(&input_node);
+      }
       dynamic_kernel->InferShape();
       dynamic_kernel->UpdateArgs();
     }

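Why the stack: InferShapeForNopNode above walks backwards from a dynamic-shape kernel's input through consecutive nop nodes, then pops so that shapes are re-inferred starting from the node closest to the real producer; each nop node therefore sees an already-refreshed input shape. A minimal self-contained sketch of the same pattern, with hypothetical Node, prev, and Infer standing in for AnfNodePtr, AnfAlgo::GetPrevNodeOutput, and AnfAlgo::InferShape:

#include <stack>

struct Node {
  Node *prev = nullptr;  // producer of this node's input 0
  bool is_nop = false;   // pass-through op whose output shape tracks its input
  void Infer() {}        // stand-in: recompute output shape from the input's current shape
};

// Re-infer every nop node between `input` and its nearest non-nop producer.
inline void InferNopChain(Node *input) {
  if (input == nullptr || !input->is_nop) {
    return;  // real (shape-producing) nodes need no fix-up
  }
  std::stack<Node *> road;
  road.push(input);
  // Walk toward the producer, collecting the whole nop chain.
  for (Node *n = input->prev; n != nullptr && n->is_nop; n = n->prev) {
    road.push(n);
  }
  // Pop producer-side first: each Infer() then sees a freshly updated input shape.
  while (!road.empty()) {
    road.top()->Infer();
    road.pop();
  }
}
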
View File

@@ -48,6 +48,7 @@ class DynamicKernel {
   virtual void Initialize();
   std::string GetKernelName() { return cnode_ptr_->fullname_with_scope(); }
   int GetKernelType();
+  CNodePtr kernel_node() const { return cnode_ptr_; }

 protected:
   void RebuildDependTensor();