forked from mindspore-Ecosystem/mindspore

nop node infer shape

commit 04136bdf66 (parent dabb82ec7a)
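In brief: this change moves the ShapeSize and TypeIdSize helpers out of the trans namespace (common/trans.cc and common/trans.h) and switches every call site to the abstract:: versions provided by abstract/utils.h. It also teaches the Ascend runtime to re-infer shapes for chains of nop nodes feeding a dynamic-shape kernel (the new InferShapeForNopNode helper), and exposes a kernel_node() accessor on DynamicKernel to make that possible.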
@@ -17,6 +17,7 @@
 #include "backend/kernel_compiler/rts/memcpy_async.h"
 #include <memory>
 #include <string>
+#include "abstract/utils.h"
 #include "runtime/mem.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "common/trans.h"
@@ -89,7 +90,7 @@ void MemCpyAsyncKernel::GetInputOutputTotalCount(const AnfNodePtr &anf_node) {
   if (input_size != 1) {
     MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1";
   }
-  size_t type_size = trans::TypeIdSize(input_type_id_);
+  size_t type_size = abstract::TypeIdSize(input_type_id_);
   std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, 0);
   size_t total_size = 1;
   for (size_t i = 0; i < shape_i.size(); i++) {

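Note: the computation around the swapped line is unchanged. type_size is the per-element byte width, and the loop that follows (its body is truncated in this hunk) accumulates the product of the device-shape dimensions into total_size; for example a float16 input (TypeIdSize = 2) with device shape (2, 3, 4) would give total_size = 2 × 3 × 4 × 2 = 48 bytes.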
@@ -20,6 +20,7 @@

 #include "c_ops/primitive_c.h"
 #include "ir/manager.h"
+#include "abstract/utils.h"
 #include "backend/kernel_compiler/common_utils.h"
 #include "base/core_ops.h"
 #include "common/trans.h"
@@ -1093,7 +1094,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
   (void)std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape_tmp), IntToSize);
   AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape_tmp},
                                       input_node.get());
-  size = trans::ShapeSize(shape_tmp) * trans::TypeIdSize(tensor->data_type());
+  size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type());
  }
  if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
   auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);

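Note: in this LoadInputData path, shape_tmp is rebuilt from the incoming tensor's actual shape and the node's inferred type/shape is refreshed before the byte size is recomputed, so under dynamic shapes the size reflects the runtime shape rather than the compile-time one. Only the helper namespace changes in this hunk.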
@@ -18,6 +18,7 @@
 #include <numeric>
 #include <utility>
 #include "utils/ms_utils.h"
+#include "abstract/utils.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/kernel_compiler/kernel.h"
 #include "runtime/device/convert_tensor_utils.h"
@@ -28,12 +29,6 @@
 namespace mindspore {
 namespace trans {
 enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNdhwc };
-const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1},    {kNumberTypeInt, 4},     {kNumberTypeInt8, 1},
-                                           {kNumberTypeInt16, 2},   {kNumberTypeInt32, 4},   {kNumberTypeInt64, 8},
-                                           {kNumberTypeUInt, 4},    {kNumberTypeUInt8, 1},   {kNumberTypeUInt16, 2},
-                                           {kNumberTypeUInt32, 4},  {kNumberTypeUInt64, 8},  {kNumberTypeFloat, 4},
-                                           {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}};
-
 inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) {
   switch (size) {
     case 1:
@@ -117,8 +112,8 @@ const std::map<std::pair<TypeId, TypeId>, DataTypeTransMode> mode_map{
   {std::pair<TypeId, TypeId>(kNumberTypeBool, kNumberTypeFloat16), FROM_BOOL_TO_FLOAT16}};

 void CheckMemSize(const TypeIdArgs &args) {
-  auto src_type_size = TypeIdSize(args.host_data_type);
-  auto dst_type_size = TypeIdSize(args.device_data_type);
+  auto src_type_size = abstract::TypeIdSize(args.host_data_type);
+  auto dst_type_size = abstract::TypeIdSize(args.device_data_type);
   if (src_type_size < 1 || dst_type_size < 1) {
     MS_LOG(EXCEPTION) << "Invalid src or dst data type.";
   }
@@ -192,7 +187,7 @@ bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const

 size_t CubeSizeByType(const TypeId data_type) {
   const size_t default_error = 0;
-  auto dt_size = TypeIdSize(data_type);
+  auto dt_size = abstract::TypeIdSize(data_type);
   if (dt_size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return default_error;
@@ -202,19 +197,6 @@ size_t CubeSizeByType(const TypeId data_type) {
   return kCubeSize;
 }

-size_t ShapeSize(const std::vector<size_t> &shape) {
-  return std::accumulate(shape.begin(), shape.end(), IntToSize(1), std::multiplies<size_t>());
-}
-
-size_t TypeIdSize(const TypeId data_type) {
-  const size_t unsupported_type_error = 0;
-  auto iter = type_map.find(data_type);
-  if (iter != type_map.end()) {
-    return iter->second;
-  }
-  return unsupported_type_error;
-}
-
 namespace {
 bool CheckDims(const std::vector<size_t> &shape) {
   if (shape.size() != kNchwDims) {
@@ -477,12 +459,12 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
   }
   MS_EXCEPTION_IF_NULL(size);
   MS_EXCEPTION_IF_NULL(total_size);
-  *size = TypeIdSize(args.src_data_type);
+  *size = abstract::TypeIdSize(args.src_data_type);
   if (*size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  *total_size = ShapeSize(args.device_shape) * (*size);
+  *total_size = abstract::ShapeSize(args.device_shape) * (*size);
   if (*total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << *total_size << ", device_size:" << args.device_size;
     return false;
@@ -516,7 +498,7 @@ bool TransFormat(const FormatArgs &args, void *result) {
     {kOpFormat_NC1HWC0, NchwToNc1hwc0},        {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
     {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}};
   MS_LOG(DEBUG) << "Start trans format.";
-  if (TypeIdSize(args.src_data_type) < 1) {
+  if (abstract::TypeIdSize(args.src_data_type) < 1) {
     MS_LOG(ERROR) << "Invalid datatype..";
     return false;
   }
@@ -538,7 +520,7 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
     {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
     {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}};
   MS_LOG(DEBUG) << "Start trans format.";
-  if (TypeIdSize(args.src_data_type) < 1) {
+  if (abstract::TypeIdSize(args.src_data_type) < 1) {
     MS_LOG(ERROR) << "Invalid datatype..";
     return false;
   }
@@ -624,7 +606,7 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
@@ -685,12 +667,12 @@ bool FracZToNchw(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  auto total_size = ShapeSize(args.device_shape) * size;
+  auto total_size = abstract::ShapeSize(args.device_shape) * size;
   if (total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
     return false;
@@ -828,13 +810,13 @@ bool NchwToFracNz(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid shape size.";
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype";
     return false;
   }

-  auto dst_size = ShapeSize(args.device_shape) * size;
+  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
   if (dst_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
     return false;
@@ -890,13 +872,13 @@ bool FracNzToNchw(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid shape size.";
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype";
     return false;
   }

-  auto dst_size = ShapeSize(args.device_shape) * size;
+  auto dst_size = abstract::ShapeSize(args.device_shape) * size;
   if (dst_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << dst_size << ", device_size:" << args.device_size;
     return false;
@@ -947,12 +929,12 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  auto total_size = ShapeSize(args.device_shape) * size;
+  auto total_size = abstract::ShapeSize(args.device_shape) * size;
   if (total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
     return false;
@@ -1005,12 +987,12 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) {
     MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims;
     return false;
   }
-  auto size = TypeIdSize(args.src_data_type);
+  auto size = abstract::TypeIdSize(args.src_data_type);
   if (size < 1) {
     MS_LOG(ERROR) << "Illegal dtype.";
     return false;
   }
-  auto total_size = ShapeSize(args.device_shape) * size;
+  auto total_size = abstract::ShapeSize(args.device_shape) * size;
   if (total_size != args.device_size) {
     MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
     return false;

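Note: the two deleted helpers are consolidated into abstract/utils.h rather than dropped; the abstract:: calls introduced throughout this diff are expected to be drop-in equivalents. A self-contained sketch of their semantics, taken from the removed definitions above (the simplified TypeId enum is a stand-in for MindSpore's real one):

#include <cstddef>
#include <functional>
#include <map>
#include <numeric>
#include <vector>

// Simplified stand-in for MindSpore's TypeId enum, for illustration only.
enum TypeId { kNumberTypeBool, kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeInt64 };

// Byte width per element type, mirroring the type_map removed above.
static const std::map<TypeId, size_t> type_map = {
    {kNumberTypeBool, 1}, {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeInt64, 8}};

// Element count of a shape: the product of all dimensions (empty shape -> 1, i.e. a scalar).
size_t ShapeSize(const std::vector<size_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
}

// Byte width of one element; returns 0 for an unsupported dtype, which callers treat as an error.
size_t TypeIdSize(TypeId data_type) {
  auto iter = type_map.find(data_type);
  return iter != type_map.end() ? iter->second : 0;
}

// Usage: a float32 tensor of shape (2, 3, 4) occupies
// ShapeSize({2, 3, 4}) * TypeIdSize(kNumberTypeFloat32) = 24 * 4 = 96 bytes.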
@@ -48,8 +48,6 @@ struct FormatArgs {
   TypeId src_data_type;
 };

-size_t TypeIdSize(const TypeId data_type);
-size_t ShapeSize(const std::vector<size_t> &shape);
 size_t CubeSizeByType(const TypeId data_type);

 std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<Axis> &padding_axis = {});

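Note: with these declarations gone from common/trans.h, any code that still needs TypeIdSize or ShapeSize must include abstract/utils.h and qualify the calls with abstract::, as the .cc hunks in this diff do.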
@@ -26,6 +26,7 @@
 #include "runtime/device/convert_tensor_utils.h"
 #include "ir/dtype/type.h"
 #include "ir/tensor.h"
+#include "abstract/utils.h"
 #include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
 #include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h"
 #include "utils/utils.h"
@@ -298,7 +299,7 @@ bool AscendDeviceAddress::SyncDeviceToHost(const ShapeVector &shape, size_t size
   } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
     sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, ptr_, size_);
   } else {
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     auto host = std::vector<uint8_t>(size_);
     SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST);
     const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size_};
@@ -413,11 +414,11 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
   MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
   auto host_size = size;
   if (type_id_ != type) {
-    auto device_dtype_size = trans::TypeIdSize(type_id_);
+    auto device_dtype_size = abstract::TypeIdSize(type_id_);
     if (device_dtype_size < 1) {
       MS_LOG(ERROR) << "Illegal dtype.";
     }
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     size = device_dtype_size * shape_size;
   }
   size = GetCommonAlignSize(size);
@@ -431,7 +432,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const
   } else {
     auto host = std::vector<uint8_t>(size);
     SyncMemory(host.data(), output_address->GetPtr(), size, RT_MEMCPY_DEVICE_TO_HOST);
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, host_size};
     sync_ok = trans::TransDataType(type_args, host_ptr);
     if (!sync_ok) {
@@ -500,7 +501,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh
       MS_LOG(ERROR) << "Trans format failed.";
       return false;
     }
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size};
     sync_ok = trans::TransDataType(type_args, host_ptr);
     if (!sync_ok) {
@@ -537,7 +538,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size
   } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
     sync_ok = Float64ToFloatAndSyncHostToDevice(ptr_, size_, host_ptr, size);
   } else {
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size};
     auto host_tmp = std::vector<uint8_t>(size_);
     sync_ok = trans::TransDataType(type_args, host_tmp.data());
@@ -581,7 +582,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   }
   if (type_id_ != type) {
-    auto shape_size = trans::ShapeSize(host_shape);
+    auto shape_size = abstract::ShapeSize(host_shape);
     const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size};
     auto host_tmp = std::vector<uint8_t>(size_);
     sync_ok = trans::TransDataType(type_args, host_tmp.data());

@@ -22,6 +22,8 @@
 #include <exception>
 #include <algorithm>
 #include <thread>
+#include <stack>
+#include "abstract/primitive_infer_map.h"
 #include "debug/data_dump/e2e_dump_util.h"
 #include "runtime/device/ascend/ascend_device_address.h"
 #include "runtime/device/cpu/mpi/mpi_interface.h"
@@ -39,6 +41,7 @@
 #include "backend/session/anf_runtime_algorithm.h"
 #include "runtime/device/ascend/profiling/profiling_utils.h"
 #include "backend/kernel_compiler/tbe/tbe_utils.h"
+#include "backend/optimizer/common/helper.h"
 #include "runtime/device/ascend/ascend_memory_manager.h"
 #include "debug/tensor_load.h"
 #include "debug/data_dump/dump_json_parser.h"
@@ -110,6 +113,34 @@ std::string GetRankId() {
   }
   return rank_id_str;
 }
+
+void InferShapeForNopNode(AnfNodePtr *input_node) {
+  MS_EXCEPTION_IF_NULL(*input_node);
+  if (!opt::IsNopNode(*input_node)) {
+    MS_LOG(INFO) << "Input node is not a nop node, no need infer.";
+    return;
+  }
+  MS_LOG(INFO) << "Infer shape for nop node.";
+  std::stack<AnfNodePtr> nop_road;
+  nop_road.push(*input_node);
+
+  while (true) {
+    auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(*input_node, 0);
+    auto in_node = input_node_with_idx.first;
+    MS_EXCEPTION_IF_NULL(in_node);
+    if (opt::IsNopNode(in_node)) {
+      nop_road.push(in_node);
+      *input_node = in_node;
+    } else {
+      break;
+    }
+  }
+  while (!nop_road.empty()) {
+    auto nop_node = nop_road.top();
+    AnfAlgo::InferShape(nop_node->cast<CNodePtr>());
+    nop_road.pop();
+  }
+}
 } // namespace

 std::vector<rtExceptionInfo> AscendKernelRuntime::exception_infoes_;

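Note: the new helper handles the case where a dynamic-shape kernel's input is produced by a nop node (a pass-through op that is skipped at execution time, so its inferred output shape can be stale). It walks backwards through the chain of consecutive nop nodes via input 0, pushing each onto a stack, then pops them so shapes are re-inferred producer-first and each nop node sees an up-to-date input shape. A self-contained model of the traversal (Node and NodePtr are simplified stand-ins for MindSpore's AnfNodePtr and AnfAlgo, for illustration only):

#include <memory>
#include <stack>

// Simplified graph node; the real AnfNodePtr carries much more.
struct Node {
  bool is_nop = false;
  std::shared_ptr<Node> producer;  // node feeding input 0
  void InferShape() { /* recompute this node's output shape from its producer's */ }
};
using NodePtr = std::shared_ptr<Node>;

// Collect the chain of consecutive nop nodes above `node`, then re-infer
// their shapes from the one nearest the real producer down to `node` itself.
void InferShapeForNopChain(NodePtr node) {
  if (!node || !node->is_nop) {
    return;  // nothing to do for a real (non-nop) producer
  }
  std::stack<NodePtr> nop_road;
  nop_road.push(node);
  while (node->producer && node->producer->is_nop) {
    node = node->producer;
    nop_road.push(node);
  }
  while (!nop_road.empty()) {  // stack top = closest to the first non-nop node
    nop_road.top()->InferShape();
    nop_road.pop();
  }
}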
@@ -633,6 +664,15 @@ bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph *grap
     }

     if (dynamic_kernel->is_dynamic_shape()) {
+      auto kernel_node = dynamic_kernel->kernel_node();
+      MS_EXCEPTION_IF_NULL(kernel_node);
+      auto input_size = AnfAlgo::GetInputTensorNum(kernel_node);
+      for (size_t i = 0; i < input_size; i++) {
+        auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(kernel_node, i);
+        auto input_node = input_node_with_index.first;
+        MS_EXCEPTION_IF_NULL(input_node);
+        InferShapeForNopNode(&input_node);
+      }
       dynamic_kernel->InferShape();
       dynamic_kernel->UpdateArgs();
     }

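Note on the call site: before a dynamic-shape kernel runs, each of its inputs is now traced through InferShapeForNopNode so any stale nop-node shapes upstream are refreshed first; only then does the kernel run its own InferShape() and UpdateArgs().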
@@ -48,6 +48,7 @@ class DynamicKernel {
   virtual void Initialize();
   std::string GetKernelName() { return cnode_ptr_->fullname_with_scope(); }
   int GetKernelType();
+  CNodePtr kernel_node() const { return cnode_ptr_; }

  protected:
   void RebuildDependTensor();
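Note: the kernel_node() accessor added here is what lets RunDynamicKernelAsync (above) reach the underlying CNode of a DynamicKernel and walk its inputs.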