!6728 [Ascend][DynamicShape] Dynamic shape feature

Merge pull request !6728 from caifubi/dynamic_shape_share_2
This commit is contained in:
mindspore-ci-bot 2020-10-12 09:07:39 +08:00 committed by Gitee
commit c951d42c2c
160 changed files with 4618 additions and 476 deletions

View File

@ -18,6 +18,7 @@ import os
import sys import sys
from te.platform.cce_conf import te_set_version from te.platform.cce_conf import te_set_version
from te.platform.fusion_util import fusion_op from te.platform.fusion_util import fusion_op
import te
from common import check_kernel_info, get_args, get_build_in_impl_path from common import check_kernel_info, get_args, get_build_in_impl_path
build_in_impl_path = get_build_in_impl_path() build_in_impl_path = get_build_in_impl_path()
@ -38,6 +39,16 @@ def _initialize(impl_path):
sys.path.insert(0, op_module_name) sys.path.insert(0, op_module_name)
def _replace_range(args):
for arg in args:
if not arg.__contains__('range'):
continue
shape_range = arg["range"]
for range_item in shape_range:
for index, value in enumerate(range_item):
if value < 0:
range_item[index] = None
def build_op(build_type, json_str): def build_op(build_type, json_str):
""" """
call op functions with function name and input args json_str call op functions with function name and input args json_str
@ -71,11 +82,18 @@ def build_op(build_type, json_str):
outputs_args = get_args(kernel_info['op_info'], 'outputs') outputs_args = get_args(kernel_info['op_info'], 'outputs')
attrs_args = get_args(kernel_info['op_info'], 'attrs') attrs_args = get_args(kernel_info['op_info'], 'attrs')
kernel_name = kernel_info['op_info']['kernel_name'] kernel_name = kernel_info['op_info']['kernel_name']
is_dynamic_shape = kernel_info['op_info']['is_dynamic_shape']
if is_dynamic_shape:
_replace_range(inputs_args)
_replace_range(outputs_args)
if custom_flag: if custom_flag:
op_module = __import__(op_name) op_module = __import__(op_name)
else: else:
op_module = __import__("impl."+op_name, globals(), locals(), [op_name], 0) if is_dynamic_shape:
op_module = __import__("impl.dynamic."+op_name, globals(), locals(), [op_name], 0)
else:
op_module = __import__("impl."+op_name, globals(), locals(), [op_name], 0)
# get function # get function
if build_type == op_build: if build_type == op_build:
if custom_flag: if custom_flag:
@ -92,7 +110,12 @@ def build_op(build_type, json_str):
if kernel_name[0:19] == "bounding_box_encode": if kernel_name[0:19] == "bounding_box_encode":
return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name_val=kernel_name) return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name_val=kernel_name)
return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) if is_dynamic_shape:
with te.op.dynamic():
op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
return te.op.get_compile_info()
else:
return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
except Exception as e: except Exception as e:
raise RuntimeError(e) raise RuntimeError(e)

View File

@ -78,6 +78,7 @@ def _check_supported(kernel_info):
""" """
try: try:
op_name = kernel_info['op_info']['name'] op_name = kernel_info['op_info']['name']
is_dynamic_shape = kernel_info['op_info']['is_dynamic_shape']
impl_path = build_in_impl_path impl_path = build_in_impl_path
custom_flag = False custom_flag = False
if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None: if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
@ -92,8 +93,11 @@ def _check_supported(kernel_info):
if custom_flag: if custom_flag:
op_module = __import__(op_name) op_module = __import__(op_name)
elif is_dynamic_shape:
op_module = __import__("impl.dynamic." + op_name, globals(), locals(), [op_name], 0)
else: else:
op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0) op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
# get function # get function
if not hasattr(op_module, "check_supported"): if not hasattr(op_module, "check_supported"):
return "" return ""

View File

@ -219,6 +219,7 @@ if (ENABLE_D)
set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common) set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common)
set(ASCEND_DRIVER_BACK_PATH ${ASCEND_PATH}/driver/lib64/driver) set(ASCEND_DRIVER_BACK_PATH ${ASCEND_PATH}/driver/lib64/driver)
set(ASCEND_RUNTIME_PATH ${ASCEND_PATH}/fwkacllib/lib64) set(ASCEND_RUNTIME_PATH ${ASCEND_PATH}/fwkacllib/lib64)
set(ASCEND_OPP_PATH ${ASCEND_PATH}/opp/op_impl/built-in/ai_core/tbe/op_tiling)
endif() endif()
MESSAGE("USE DAV LIB PATH: ${ASCEND_PATH}") MESSAGE("USE DAV LIB PATH: ${ASCEND_PATH}")
@ -228,7 +229,8 @@ if (ENABLE_D)
find_library(TSDCLIENT tsdclient HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_DRIVER_BACK_PATH}) find_library(TSDCLIENT tsdclient HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_DRIVER_BACK_PATH})
find_library(DATATRANSFER datatransfer HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_DRIVER_BACK_PATH}) find_library(DATATRANSFER datatransfer HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_DRIVER_BACK_PATH})
find_library(PROFILING msprofiler ${ASCEND_RUNTIME_PATH}) find_library(PROFILING msprofiler ${ASCEND_RUNTIME_PATH})
target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER}) find_library(OPTILING optiling ${ASCEND_OPP_PATH})
target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER} ${OPTILING})
target_link_libraries(mindspore -Wl,--start-group proto_input ${PROFILING} mindspore::protobuf -Wl,--end-group) target_link_libraries(mindspore -Wl,--start-group proto_input ${PROFILING} mindspore::protobuf -Wl,--end-group)
elseif (CMAKE_SYSTEM_NAME MATCHES "Windows") elseif (CMAKE_SYSTEM_NAME MATCHES "Windows")
target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf mindspore::sentencepiece -Wl,--end-group) target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf mindspore::sentencepiece -Wl,--end-group)
@ -258,6 +260,7 @@ if (ENABLE_D)
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/lib64) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/lib64)
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/fwkacllib/lib64) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/fwkacllib/lib64)
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons)
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling)
elseif (ENABLE_GPU) elseif (ENABLE_GPU)
set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/cuda/lib64) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/cuda/lib64)
endif () endif ()
@ -315,6 +318,8 @@ add_library(inference SHARED
${CMAKE_CURRENT_SOURCE_DIR}/backend/session/infer_session.cc ${CMAKE_CURRENT_SOURCE_DIR}/backend/session/infer_session.cc
${LOAD_ONNX_SRC} ${LOAD_ONNX_SRC}
) )
set_target_properties(inference PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})
target_link_libraries(inference PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY} target_link_libraries(inference PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
-Wl,--whole-archive mindspore proto_input -Wl,--no-whole-archive mindspore_gvar) -Wl,--whole-archive mindspore proto_input -Wl,--no-whole-archive mindspore_gvar)

View File

@ -15,6 +15,7 @@ if (ENABLE_D)
"akg/akg_kernel_attrs_process.cc" "akg/akg_kernel_attrs_process.cc"
"akg/akg_kernel_metadata.cc" "akg/akg_kernel_metadata.cc"
"tbe/*.cc" "tbe/*.cc"
"host/*.cc"
"aicpu/*.cc" "aicpu/*.cc"
"rts/*.cc" "rts/*.cc"
"hccl/*.cc" "hccl/*.cc"

View File

@ -289,51 +289,25 @@ bool CreateNodeDefBytes(const std::shared_ptr<AnfNode> &anf_node,
return true; return true;
} }
bool CreateExtInfo(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<AicpuOpKernelMod> &kernel_mod_ptr) { uint64_t SetExtInfoShapeType(char *ext_info_buf, uint64_t ext_info_offset) {
if (!anf_node->isa<CNode>()) {
return true;
}
if (!AnfAlgo::IsDynamicShape(anf_node)) {
return true;
}
MS_LOG(INFO) << "CreateExtInfo start, " << anf_node->fullname_with_scope();
int32_t unknown_shape_type = UnknowShapeOpType::DEPEND_COMPUTE;
uint64_t ext_info_head_len = kExtInfoHeadSize;
std::string ext_info;
size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
// 1.addr:unknown shape type
uint64_t ext_info_len = ext_info.size();
ext_info_len += ext_info_head_len + sizeof(int32_t);
// 2.addr:input ShapeAndType
ext_info_len += ext_info_head_len + input_num * sizeof(ShapeAndType);
// 3.addr:output ShapeAndType
ext_info_len += ext_info_head_len + output_num * sizeof(ShapeAndType);
uint64_t ext_info_offset = ext_info.size();
ext_info.resize(ext_info_len, 0);
char *ext_info_buf = ext_info.data();
// deal1: unknown shape type // deal1: unknown shape type
ExtInfo *info = reinterpret_cast<ExtInfo *>(ext_info_buf + ext_info_offset); ExtInfo *info = reinterpret_cast<ExtInfo *>(ext_info_buf + ext_info_offset);
info->infoType = FWK_ADPT_EXT_SHAPE_TYPE; info->infoType = FWK_ADPT_EXT_SHAPE_TYPE;
info->infoLen = sizeof(int32_t); info->infoLen = sizeof(int32_t);
ext_info_offset += ext_info_head_len; ext_info_offset += kExtInfoHeadSize;
int32_t *shape_type = reinterpret_cast<int32_t *>(ext_info_buf + ext_info_offset); int32_t *shape_type = reinterpret_cast<int32_t *>(ext_info_buf + ext_info_offset);
*shape_type = unknown_shape_type; *shape_type = UnknowShapeOpType::DEPEND_COMPUTE;
ext_info_offset += info->infoLen; ext_info_offset += info->infoLen;
return ext_info_offset;
}
uint64_t SetExtInfoInputShapeType(char *ext_info_buf, uint64_t ext_info_offset,
const std::shared_ptr<AnfNode> &anf_node, size_t input_num) {
// deal2:input ShapeAndType // deal2:input ShapeAndType
info = reinterpret_cast<ExtInfo *>(ext_info_buf + ext_info_offset); ExtInfo *info = reinterpret_cast<ExtInfo *>(ext_info_buf + ext_info_offset);
info->infoType = FWK_ADPT_EXT_INPUT_SHAPE; info->infoType = FWK_ADPT_EXT_INPUT_SHAPE;
info->infoLen = input_num * sizeof(ShapeAndType); info->infoLen = input_num * sizeof(ShapeAndType);
ext_info_offset += ext_info_head_len; ext_info_offset += kExtInfoHeadSize;
ShapeAndType *inputs = reinterpret_cast<ShapeAndType *>(ext_info_buf + ext_info_offset); ShapeAndType *inputs = reinterpret_cast<ShapeAndType *>(ext_info_buf + ext_info_offset);
for (size_t input_index = 0; input_index < input_num; input_index++) { for (size_t input_index = 0; input_index < input_num; input_index++) {
@ -364,12 +338,16 @@ bool CreateExtInfo(const std::shared_ptr<AnfNode> &anf_node, const std::shared_p
} }
} }
ext_info_offset += info->infoLen; ext_info_offset += info->infoLen;
return ext_info_offset;
}
uint64_t SetExtInfoOutputShapeType(char *ext_info_buf, uint64_t ext_info_offset,
const std::shared_ptr<AnfNode> &anf_node, size_t output_num) {
// deal3:output ShapeAndType // deal3:output ShapeAndType
info = reinterpret_cast<ExtInfo *>(ext_info_buf + ext_info_offset); ExtInfo *info = reinterpret_cast<ExtInfo *>(ext_info_buf + ext_info_offset);
info->infoType = FWK_ADPT_EXT_OUTPUT_SHAPE; info->infoType = FWK_ADPT_EXT_OUTPUT_SHAPE;
info->infoLen = output_num * sizeof(ShapeAndType); info->infoLen = output_num * sizeof(ShapeAndType);
ext_info_offset += ext_info_head_len; ext_info_offset += kExtInfoHeadSize;
ShapeAndType *outputs = reinterpret_cast<ShapeAndType *>(ext_info_buf + ext_info_offset); ShapeAndType *outputs = reinterpret_cast<ShapeAndType *>(ext_info_buf + ext_info_offset);
for (size_t output_index = 0; output_index < output_num; output_index++) { for (size_t output_index = 0; output_index < output_num; output_index++) {
@ -387,6 +365,47 @@ bool CreateExtInfo(const std::shared_ptr<AnfNode> &anf_node, const std::shared_p
} }
} }
ext_info_offset += info->infoLen;
return ext_info_offset;
}
bool CreateExtInfo(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<AicpuOpKernelMod> &kernel_mod_ptr) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
if (!anf_node->isa<CNode>()) {
return true;
}
if (!AnfAlgo::IsDynamicShape(anf_node)) {
return true;
}
MS_LOG(INFO) << "CreateExtInfo start, " << anf_node->fullname_with_scope();
uint64_t ext_info_head_len = kExtInfoHeadSize;
std::string ext_info;
size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
// 1.addr:unknown shape type
uint64_t ext_info_len = ext_info.size();
ext_info_len += ext_info_head_len + sizeof(int32_t);
// 2.addr:input ShapeAndType
ext_info_len += ext_info_head_len + input_num * sizeof(ShapeAndType);
// 3.addr:output ShapeAndType
ext_info_len += ext_info_head_len + output_num * sizeof(ShapeAndType);
uint64_t ext_info_offset = ext_info.size();
ext_info.resize(ext_info_len, 0);
char *ext_info_buf = ext_info.data();
ext_info_offset = SetExtInfoShapeType(ext_info_buf, ext_info_offset);
ext_info_offset = SetExtInfoInputShapeType(ext_info_buf, ext_info_offset, anf_node, input_num);
ext_info_offset = SetExtInfoOutputShapeType(ext_info_buf, ext_info_offset, anf_node, output_num);
MS_LOG(INFO) << "Check ext_info_len:" << ext_info_len << " ext_info_offset:" << ext_info_offset;
// set ext info // set ext info
kernel_mod_ptr->SetExtInfo(ext_info); kernel_mod_ptr->SetExtInfo(ext_info);
return true; return true;

View File

@ -26,8 +26,13 @@
#include "utils/convert_utils.h" #include "utils/convert_utils.h"
#include "backend/kernel_compiler/aicpu/aicpu_util.h" #include "backend/kernel_compiler/aicpu/aicpu_util.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "runtime/device/ascend/executor/ai_cpu_dynamic_kernel.h"
#include "runtime/device/kernel_runtime.h"
#include "runtime/device/ascend/executor/host_dynamic_kernel.h"
using AicpuTaskInfoPtr = std::shared_ptr<ge::model_runner::AicpuTaskInfo>; using AicpuTaskInfoPtr = std::shared_ptr<ge::model_runner::AicpuTaskInfo>;
using AicpuDynamicKernel = mindspore::device::ascend::AiCpuDynamicKernel;
using HostDynamicKernel = mindspore::device::ascend::HostDynamicKernel;
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
@ -93,7 +98,7 @@ void AicpuOpKernelMod::CreateCpuKernelInfo(const std::vector<AddressPtr> &inputs
param_len += node_def_len; param_len += node_def_len;
param_len += sizeof(uint32_t); param_len += sizeof(uint32_t);
AicpuParamHead aicpu_param_head; AicpuParamHead aicpu_param_head{};
aicpu_param_head.length = param_len; aicpu_param_head.length = param_len;
aicpu_param_head.ioAddrNum = io_addrs_num; aicpu_param_head.ioAddrNum = io_addrs_num;
@ -178,5 +183,15 @@ std::vector<TaskInfoPtr> AicpuOpKernelMod::GenTask(const std::vector<AddressPtr>
MS_LOG(INFO) << "AicpuOpKernelMod GenTask end"; MS_LOG(INFO) << "AicpuOpKernelMod GenTask end";
return {task_info_ptr}; return {task_info_ptr};
} }
device::DynamicKernelPtr AicpuOpKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
CreateCpuKernelInfo(kernel_inputs, kernel_outputs);
return std::make_shared<AicpuDynamicKernel>(stream_ptr, cnode_ptr, args_, ext_info_, node_so_, node_name_);
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -31,6 +31,7 @@ class AicpuOpKernelMod : public AscendKernelMod {
std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) override; const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override;
void SetInputList(const std::vector<int64_t> &inputList); void SetInputList(const std::vector<int64_t> &inputList);
void SetOutputList(const std::vector<int64_t> &outputList); void SetOutputList(const std::vector<int64_t> &outputList);

View File

@ -20,7 +20,7 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
static std::map<int32_t, int32_t> MS_PROTO_DATA_TYPE_MAP = { static const std::map<int32_t, int32_t> kMsProtoDataTypeMap = {
{mindspore::TypeId::kTypeUnknown, mindspore::DataType::MS_UNKNOWN}, {mindspore::TypeId::kTypeUnknown, mindspore::DataType::MS_UNKNOWN},
{mindspore::TypeId::kNumberTypeBool, mindspore::DataType::MS_BOOL}, {mindspore::TypeId::kNumberTypeBool, mindspore::DataType::MS_BOOL},
{mindspore::TypeId::kNumberTypeInt, mindspore::DataType::MS_INT32}, {mindspore::TypeId::kNumberTypeInt, mindspore::DataType::MS_INT32},
@ -39,14 +39,38 @@ static std::map<int32_t, int32_t> MS_PROTO_DATA_TYPE_MAP = {
{mindspore::TypeId::kNumberTypeFloat64, mindspore::DataType::MS_FLOAT64}, {mindspore::TypeId::kNumberTypeFloat64, mindspore::DataType::MS_FLOAT64},
}; };
static const std::map<int32_t, int32_t> kProtoDataTypeToMsDataTypeMap = {
{mindspore::DataType::MS_UNKNOWN, mindspore::TypeId::kTypeUnknown},
{mindspore::DataType::MS_BOOL, mindspore::TypeId::kNumberTypeBool},
{mindspore::DataType::MS_INT32, mindspore::TypeId::kNumberTypeInt32},
{mindspore::DataType::MS_INT8, mindspore::TypeId::kNumberTypeInt8},
{mindspore::DataType::MS_INT16, mindspore::TypeId::kNumberTypeInt16},
{mindspore::DataType::MS_INT64, mindspore::TypeId::kNumberTypeInt64},
{mindspore::DataType::MS_UINT8, mindspore::TypeId::kNumberTypeUInt8},
{mindspore::DataType::MS_UINT16, mindspore::TypeId::kNumberTypeUInt16},
{mindspore::DataType::MS_UINT32, mindspore::TypeId::kNumberTypeUInt32},
{mindspore::DataType::MS_UINT64, mindspore::TypeId::kNumberTypeUInt64},
{mindspore::DataType::MS_FLOAT16, mindspore::TypeId::kNumberTypeFloat16},
{mindspore::DataType::MS_FLOAT32, mindspore::TypeId::kNumberTypeFloat32},
{mindspore::DataType::MS_FLOAT64, mindspore::TypeId::kNumberTypeFloat64},
};
int AicpuOpUtil::MsTypeToProtoType(TypeId ms_type) { int AicpuOpUtil::MsTypeToProtoType(TypeId ms_type) {
auto iter = MS_PROTO_DATA_TYPE_MAP.find(ms_type); auto iter = kMsProtoDataTypeMap.find(ms_type);
if (iter != MS_PROTO_DATA_TYPE_MAP.end()) { if (iter == kMsProtoDataTypeMap.end()) {
return MS_PROTO_DATA_TYPE_MAP[ms_type];
} else {
MS_LOG(ERROR) << "UnSupported ms_type value" << static_cast<int>(ms_type); MS_LOG(ERROR) << "UnSupported ms_type value" << static_cast<int>(ms_type);
return -1; return -1;
} }
return iter->second;
}
int AicpuOpUtil::ProtoTypeToMsType(int proto_type) {
auto iter = kProtoDataTypeToMsDataTypeMap.find(proto_type);
if (iter == kProtoDataTypeToMsDataTypeMap.end()) {
MS_LOG(ERROR) << "UnSupported proto_type value:" << proto_type;
return -1;
}
return iter->second;
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -55,13 +55,6 @@ struct AicpuParamHead {
uint64_t extInfoAddr; // extInfo address uint64_t extInfoAddr; // extInfo address
} __attribute__((packed)); } __attribute__((packed));
const uint32_t kExtInfoHeadSize = 8;
struct ExtInfo {
int32_t infoType; // extend type
uint32_t infoLen; // length for infoMsg
char infoMsg[0]; // extend value
} __attribute__((packed));
// Extent info ShapeAndType // Extent info ShapeAndType
const uint32_t kMaxShapeDims = 8; const uint32_t kMaxShapeDims = 8;
struct ShapeAndType { struct ShapeAndType {
@ -69,6 +62,14 @@ struct ShapeAndType {
int64_t dims[kMaxShapeDims]; int64_t dims[kMaxShapeDims];
} __attribute__((packed)); } __attribute__((packed));
// Extend info structure for extInfoAddr
const uint32_t kExtInfoHeadSize = 8;
struct ExtInfo {
int32_t infoType; // extend type
uint32_t infoLen; // length for infoMsg
char infoMsg[0]; // extend value
} __attribute__((packed));
// Extend Info type for task // Extend Info type for task
enum FWKTaskExtInfoType { enum FWKTaskExtInfoType {
FWK_ADPT_EXT_SHAPE_TYPE = 0, FWK_ADPT_EXT_SHAPE_TYPE = 0,
@ -88,6 +89,7 @@ enum UnknowShapeOpType {
class AicpuOpUtil { class AicpuOpUtil {
public: public:
static int MsTypeToProtoType(TypeId ms_type); static int MsTypeToProtoType(TypeId ms_type);
static int ProtoTypeToMsType(int proto_type);
private: private:
// kernel id // kernel id

View File

@ -15,15 +15,34 @@
*/ */
#include "backend/kernel_compiler/hccl/hccl_kernel.h" #include "backend/kernel_compiler/hccl/hccl_kernel.h"
#include <map>
#include "runtime/device/ascend/tasksink/runtime_utils.h" #include "runtime/device/ascend/tasksink/runtime_utils.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "runtime/device/kernel_runtime.h"
#include "runtime/device/ascend/executor/hccl_dynamic_kernel.h"
using HcclTaskInfoPtr = std::shared_ptr<ge::model_runner::HcclTaskInfo>; using HcclTaskInfoPtr = std::shared_ptr<ge::model_runner::HcclTaskInfo>;
using ge::model_runner::HcclTaskInfo; using ge::model_runner::HcclTaskInfo;
using mindspore::device::ascend::tasksink::RuntimeUtils; using mindspore::device::ascend::tasksink::RuntimeUtils;
namespace {
static std::map<std::string, std::string> kMsOpNameToHcomHcclType = {
{mindspore::kAllReduceOpName, mindspore::kHcomOpTypeAllReduce},
{mindspore::kAllGatherOpName, mindspore::kHcomOpTypeAllGather},
{mindspore::kBroadcastOpName, mindspore::kHcomOpTypeBroadcast},
{mindspore::kReduceScatterOpName, mindspore::kHcomOpTypeReduceScatter}};
std::string MsOpNameToHcomOpType(const std::string &ms_op_type) {
auto iter = kMsOpNameToHcomHcclType.find(ms_op_type);
if (iter == kMsOpNameToHcomHcclType.end()) {
MS_LOG(EXCEPTION) << "Invalid MsOpType:" << ms_op_type;
}
return iter->second;
}
} // namespace
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) { void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) {
@ -156,5 +175,30 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
MS_EXCEPTION_IF_NULL(task_info_ptr); MS_EXCEPTION_IF_NULL(task_info_ptr);
return {task_info_ptr}; return {task_info_ptr};
} }
device::DynamicKernelPtr HcclKernel::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
AddressPtrList inputs;
AddressPtrList workspaces;
AddressPtrList outputs;
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &inputs, &workspaces, &outputs);
std::string hccl_type = MsOpNameToHcomOpType(AnfAlgo::GetCNodeName(anf_node_));
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "Hccl kernel input is empty";
}
if (hccl_data_type_list_.empty()) {
MS_LOG(EXCEPTION) << "Hccl data type list is empty";
}
MS_EXCEPTION_IF_NULL(inputs.at(0));
auto input_data_addr = inputs.at(0)->addr;
MS_EXCEPTION_IF_NULL(outputs.at(0));
auto output_data_addr = outputs.at(0)->addr;
HcclDataType data_type = hccl_data_type_list_[0];
auto executor = std::make_shared<device::ascend::HcclDynamicKernel>(
hccl_type, input_data_addr, output_data_addr, hccl_count_, data_type, op_type_, root_id_, stream_ptr, cnode_ptr);
return executor;
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -41,6 +41,7 @@ class HcclKernel : public AscendKernelMod {
const std::vector<size_t> &GetWorkspaceSizeList() const override; const std::vector<size_t> &GetWorkspaceSizeList() const override;
std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) override; const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override;
protected: protected:
std::vector<std::vector<size_t>> hccl_kernel_input_shape_list_; std::vector<std::vector<size_t>> hccl_kernel_input_shape_list_;

View File

@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/host/dynamic_shape_kernel.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace kernel {
void DynamicShapeKernel::Execute() {
MS_LOG(INFO) << "Execute DynamicShapeKernel Start";
auto input_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "Invalid Input Num:" << input_num;
}
auto prev_output_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, 0);
auto output_shape = std::vector<int>(SizeToInt(prev_output_shape.size()));
auto output_type = TypeId::kNumberTypeInt32;
auto output_tensor_for_sync = std::make_shared<tensor::Tensor>(output_type, output_shape);
auto data_ptr = static_cast<int32_t *>(output_tensor_for_sync->data_c());
for (size_t i = 0; i < prev_output_shape.size(); ++i) {
MS_LOG(INFO) << "DEBUG prev_output_shape[" << i << "]:" << prev_output_shape[i];
*(data_ptr + i) = prev_output_shape[i];
}
auto output_addr = AnfAlgo::GetOutputAddr(cnode_ptr_, 0);
MS_EXCEPTION_IF_NULL(output_addr);
output_addr->SyncHostToDevice(output_shape, LongToSize(output_tensor_for_sync->data().nbytes()),
output_tensor_for_sync->data_type(), output_tensor_for_sync->data_c());
MS_LOG(INFO) << "Execute DynamicShapeKernel End";
}
device::DynamicKernelPtr DynamicShapeKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
return std::make_shared<DynamicShapeKernel>(stream_ptr, cnode_ptr);
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_SHAPE_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_SHAPE_KERNEL_H_
#include <vector>
#include <memory>
#include <string>
#include "runtime/device/ascend/executor/host_dynamic_kernel.h"
#include "backend/kernel_compiler/host/host_kernel_mod.h"
using HostDynamicKernel = mindspore::device::ascend::HostDynamicKernel;
namespace mindspore {
namespace kernel {
class DynamicShapeKernel : public HostDynamicKernel {
public:
DynamicShapeKernel(void *stream, const CNodePtr &cnode_ptr) : HostDynamicKernel(stream, cnode_ptr) {}
~DynamicShapeKernel() override = default;
void Execute() override;
};
class DynamicShapeKernelMod : public HostKernelMod {
public:
DynamicShapeKernelMod() = default;
~DynamicShapeKernelMod() override = default;
device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override;
};
MS_HOST_REG_KERNEL(DynamicShape, DynamicShapeKernelMod);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_SHAPE_KERNEL_H_

View File

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/host/host_kernel_build.h"
#include <string>
#include "runtime/device/kernel_runtime.h"
#include "backend/kernel_compiler/host/host_kernel_mod.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
namespace kernel {
KernelModPtr HostOpBuild(const std::shared_ptr<AnfNode> &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
std::string opname = AnfAlgo::GetCNodeName(anf_node);
MS_LOG(INFO) << "Host op [" << opname << "]";
auto kerPtr = HostKernelFactory::Get(opname);
if (kerPtr == nullptr) {
MS_LOG(ERROR) << "Host can't find Kernel[" << opname << "]";
return nullptr;
}
if (!kerPtr->Init(anf_node)) {
MS_LOG(ERROR) << "Host Kernel initialize failed!";
return nullptr;
}
return kerPtr;
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,27 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_BUILD_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_BUILD_H_
#include <memory>
#include "backend/kernel_compiler/kernel.h"
namespace mindspore {
namespace kernel {
KernelModPtr HostOpBuild(const std::shared_ptr<AnfNode> &anf_node);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_BUILD_H_

View File

@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/host/host_kernel_metadata.h"
#include <memory>
#include <string>
#include "backend/kernel_compiler/oplib/oplib.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace kernel {
constexpr auto kDynamicShape = "DynamicShape";
void HostMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
MS_LOG(INFO) << "HostMetadataInfo.";
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
if (op_name != kDynamicShape) {
MS_LOG(DEBUG) << "Host does not have op [" << op_name << "]";
return;
}
std::vector<std::string> inputs_format{};
std::vector<TypeId> inputs_type{};
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
inputs_format.emplace_back(kOpFormat_DEFAULT);
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
}
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
outputs_format.emplace_back(kOpFormat_DEFAULT);
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
}
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
builder.SetInputsFormat(inputs_format);
builder.SetInputsDeviceType(inputs_type);
builder.SetOutputsFormat(outputs_format);
builder.SetOutputsDeviceType(outputs_type);
builder.SetKernelType(HOST_KERNEL);
kernel_info_list->push_back(builder.Build());
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,30 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_META_DATA_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_META_DATA_H_
#include <string>
#include <vector>
#include <memory>
#include "backend/kernel_compiler/kernel_build_info.h"
namespace mindspore {
namespace kernel {
// Appends host-kernel build infos for the given node to kernel_info_list;
// a no-op for ops the host backend does not handle (see the .cc for details).
void HostMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
}  // namespace kernel
}  // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_META_DATA_H_

View File

@ -0,0 +1,98 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/host/host_kernel_mod.h"
#include <memory>
#include <vector>
#include <string>
#include <utility>
#include "runtime/mem.h"
#include "utils/ms_context.h"
#include "runtime/device/kernel_runtime.h"
#include "runtime/device/ascend/executor/host_dynamic_kernel.h"
namespace mindspore {
namespace kernel {
// Records a creator function for the given op name. A name that is already
// registered keeps its original creator (map insert does not overwrite).
void HostKernelFactory::Registe(const std::string &name, HostKernelCreater &&fun) {
  (void)hostKernelMap_.insert(std::make_pair(name, std::move(fun)));
}
// Creates a new kernel mod for the given op name, or returns nullptr when the
// name is unknown or was registered with an empty creator.
std::shared_ptr<HostKernelMod> HostKernelFactory::Get(const std::string &name) {
  const auto &registry = Get().hostKernelMap_;
  auto iter = registry.find(name);
  if (iter == registry.end() || !iter->second) {
    return nullptr;
  }
  return (iter->second)();
}
// Returns the process-wide factory singleton (thread-safe since C++11
// function-local static initialization).
HostKernelFactory &HostKernelFactory::Get() {
  static HostKernelFactory factory;
  return factory;
}
// Byte sizes of each input tensor, filled in by Init().
const std::vector<size_t> &HostKernelMod::GetInputSizeList() const { return input_size_list_; }
// Byte sizes of each output tensor, filled in by Init().
const std::vector<size_t> &HostKernelMod::GetOutputSizeList() const { return output_size_list_; }
// Workspace sizes; not populated anywhere in this file, so always empty.
const std::vector<size_t> &HostKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool HostKernelMod::Init(const AnfNodePtr &anf_node) {
MS_EXCEPTION_IF_NULL(anf_node);
size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
for (size_t i = 0; i < input_num; i++) {
std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i);
TypePtr type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i));
MS_EXCEPTION_IF_NULL(type_ptr);
int64_t size_i = 1;
for (size_t j = 0; j < shape_i.size(); j++) {
size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]));
}
size_t type_byte = GetTypeByte(type_ptr);
if (type_byte == 0) {
return false;
}
size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte));
input_size_list_.push_back(LongToSize(size_i));
}
for (size_t i = 0; i < output_num; i++) {
std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i));
MS_EXCEPTION_IF_NULL(type_ptr);
int64_t size_i = 1;
for (size_t j = 0; j < shape_i.size(); j++) {
size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]));
}
size_t type_byte = GetTypeByte(type_ptr);
if (type_byte == 0) {
return false;
}
size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte));
output_size_list_.push_back(LongToSize(size_i));
}
return true;
}
// No-op that always reports success: there is no device-side launch work here.
// NOTE(review): actual execution presumably happens through the dynamic kernel
// returned by GenDynamicKernel — confirm with the runtime callers.
bool HostKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                           const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  return true;
}
// Host kernels produce no sink-mode task infos; always returns an empty list.
std::vector<TaskInfoPtr> HostKernelMod::GenTask(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
                                                const std::vector<AddressPtr> &, uint32_t) {
  return {};
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,86 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_MOD_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_MOD_H_
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "backend/kernel_compiler/ascend_kernel_mod.h"
namespace mindspore {
namespace kernel {
// Base KernelMod for kernels executed on the host instead of the Ascend device.
// Concrete subclasses must implement GenDynamicKernel (pure virtual here).
class HostKernelMod : public AscendKernelMod {
 public:
  HostKernelMod() = default;
  ~HostKernelMod() override = default;
  // Per-tensor byte sizes computed by Init().
  const std::vector<size_t> &GetInputSizeList() const override;
  const std::vector<size_t> &GetOutputSizeList() const override;
  const std::vector<size_t> &GetWorkspaceSizeList() const override;
  // Device launch is a no-op for host kernels (see host_kernel_mod.cc).
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
  // Host kernels generate no sink-mode tasks; returns an empty list.
  std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
                                   const std::vector<AddressPtr> &, uint32_t) override;
  // Every host kernel must supply its own dynamic kernel for dynamic-shape execution.
  device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override = 0;
  // Computes input/output byte sizes from the node's device shapes and dtypes.
  bool Init(const AnfNodePtr &anf_node);

 protected:
  // NOTE(review): anf_node_ and op_name_ are never assigned in this file;
  // confirm whether subclasses populate them or they are dead members.
  AnfNodePtr anf_node_;
  std::string op_name_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
using HostKernelModPtr = std::shared_ptr<HostKernelMod>;
using HostKernelModPtrList = std::vector<HostKernelModPtr>;
using HostKernelCreater = std::function<std::shared_ptr<HostKernelMod>()>;
// Singleton registry mapping op names to host-kernel creator functions.
// Populated at static-initialization time via MS_HOST_REG_KERNEL.
class HostKernelFactory {
  HostKernelFactory() = default;
  ~HostKernelFactory() = default;

 public:
  // Returns the process-wide factory instance.
  static HostKernelFactory &Get();
  // Registers a creator under the given name ("Registe" spelling is part of
  // the existing interface and must match callers).
  void Registe(const string &name, HostKernelCreater &&fun);
  // Creates a new kernel mod for the name, or nullptr when not registered.
  static std::shared_ptr<HostKernelMod> Get(const string &name);

 private:
  std::map<string, HostKernelCreater> hostKernelMap_;
};
// Registration helper: constructing a file-scope instance adds a creator to
// HostKernelFactory during static initialization (used by MS_HOST_REG_KERNEL).
class _HostKernelRegister {
 public:
  _HostKernelRegister(const string &name, HostKernelCreater &&fun) {
    HostKernelFactory::Get().Registe(name, std::move(fun));
  }
  ~_HostKernelRegister() = default;
};
// Registers `clazz` as the host kernel for op KNAME by defining a file-scope
// _HostKernelRegister whose constructor installs a factory lambda.
// The static_assert rejects classes that do not derive from HostKernelMod.
#define _MS_HOST_REG_KERNEL_REG(KNAME, clazz)                                                    \
  static_assert(std::is_base_of<HostKernelMod, clazz>::value, " must be base of HostKernelMod"); \
  static const _HostKernelRegister g_##KNAME##_##_kernel_reg(#KNAME, []() {                      \
    std::shared_ptr<clazz> ptr = nullptr;                                                        \
    ptr = std::make_shared<clazz>();                                                             \
    MS_EXCEPTION_IF_NULL(ptr);                                                                   \
    return ptr;                                                                                  \
  });

// Public registration entry point.
#define MS_HOST_REG_KERNEL(KNAME, clazz) _MS_HOST_REG_KERNEL_REG(KNAME, clazz)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_MOD_H_

View File

@ -174,6 +174,9 @@ void KernelPack::ParseKernelJson(const nlohmann::json &js) {
kernel_json_info_.block_dim = js["blockDim"]; kernel_json_info_.block_dim = js["blockDim"];
kernel_json_info_.kernel_name = js["kernelName"]; kernel_json_info_.kernel_name = js["kernelName"];
kernel_json_info_.magic = js["magic"]; kernel_json_info_.magic = js["magic"];
if (js.contains("opParaSize")) {
kernel_json_info_.op_para_size = js["opParaSize"];
}
if (js.find("parameters") != js.end()) { if (js.find("parameters") != js.end()) {
if (!js.at("parameters").is_array()) { if (!js.at("parameters").is_array()) {
MS_LOG(DEBUG) << "Format error!,parameters should be array."; MS_LOG(DEBUG) << "Format error!,parameters should be array.";

View File

@ -25,9 +25,18 @@
#include "ir/tensor.h" #include "ir/tensor.h"
#include "abstract/dshape.h" #include "abstract/dshape.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "runtime/device/executor/dynamic_kernel.h"
namespace mindspore { namespace mindspore {
enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AKG_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL }; enum KernelType : int {
UNKNOWN_KERNEL_TYPE = 0,
AKG_KERNEL,
AICPU_KERNEL,
RT_KERNEL,
HCCL_KERNEL,
TBE_KERNEL,
HOST_KERNEL
};
namespace kernel { namespace kernel {
// Supported fusion type // Supported fusion type
@ -69,7 +78,8 @@ struct KernelJsonInfo {
std::vector<size_t> parameters; std::vector<size_t> parameters;
std::string sha256; std::string sha256;
std::vector<size_t> workspaces; std::vector<size_t> workspaces;
KernelJsonInfo() : block_dim(0) {} uint32_t op_para_size;
KernelJsonInfo() : block_dim(0), op_para_size(0) {}
}; };
class KernelPack { class KernelPack {
@ -118,6 +128,7 @@ class KernelMod {
virtual const std::vector<size_t> &GetWorkspaceSizeList() const = 0; virtual const std::vector<size_t> &GetWorkspaceSizeList() const = 0;
virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) = 0; const std::vector<AddressPtr> &outputs, void *stream_ptr) = 0;
virtual device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) { return nullptr; }
virtual std::vector<size_t> GenParameters() { return {}; } virtual std::vector<size_t> GenParameters() { return {}; }
virtual void ReleaseResource() {} virtual void ReleaseResource() {}

View File

@ -83,8 +83,8 @@ std::map<int32_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo>
while (!build_manger->IsAllTaskFinish()) { while (!build_manger->IsAllTaskFinish()) {
int task_id = -1; int task_id = -1;
std::string task_result; std::string task_result;
std::string pre_build_result; std::string build_result;
auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result); auto ret = build_manger->WaitOne(&task_id, &task_result, &build_result);
if (!ret) { if (!ret) {
MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id; MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id;
} }
@ -94,7 +94,7 @@ std::map<int32_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo>
<< " change to single op build."; << " change to single op build.";
build_failed_num++; build_failed_num++;
} }
auto kernel_mod_item = build_manger->TaskFinishProcess(task_id, false); auto kernel_mod_item = build_manger->TaskFinishProcess(task_id, build_result, false);
if (kernel_mod_item.second != nullptr) { if (kernel_mod_item.second != nullptr) {
(void)kernel_mod_ret.emplace(kernel_mod_item); (void)kernel_mod_ret.emplace(kernel_mod_item);
} }

View File

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <algorithm> #include <algorithm>
#include "backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h" #include "backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h"
#include "backend/kernel_compiler/host/host_kernel_metadata.h"
#include "backend/kernel_compiler/rts/rt_kernel_info.h" #include "backend/kernel_compiler/rts/rt_kernel_info.h"
#include "backend/kernel_compiler/hccl/hccl_kernel_metadata.h" #include "backend/kernel_compiler/hccl/hccl_kernel_metadata.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h" #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h"
@ -86,6 +87,9 @@ void KernelQueryAll(const CNodePtr &kernel_node,
if (kernel_info_list->empty()) { if (kernel_info_list->empty()) {
HcclMetadataInfo(kernel_node, kernel_info_list); HcclMetadataInfo(kernel_node, kernel_info_list);
} }
if (kernel_info_list->empty()) {
HostMetadataInfo(kernel_node, kernel_info_list);
}
if (kernel_info_list->empty()) { if (kernel_info_list->empty()) {
MS_EXCEPTION(NotExistsError) MS_EXCEPTION(NotExistsError)
<< "Failed to obtain operator info, Please check whether the operator info is registered, Op full name:" << "Failed to obtain operator info, Please check whether the operator info is registered, Op full name:"

View File

@ -102,6 +102,7 @@ class OpInfo {
kernel_name_ = opinfo.kernel_name(); kernel_name_ = opinfo.kernel_name();
partial_flag_ = opinfo.partial_flag_; partial_flag_ = opinfo.partial_flag_;
dynamic_format_ = opinfo.dynamic_format_; dynamic_format_ = opinfo.dynamic_format_;
dynamic_shape_ = opinfo.dynamic_shape_;
op_pattern_ = opinfo.op_pattern(); op_pattern_ = opinfo.op_pattern();
processor_ = opinfo.processor_; processor_ = opinfo.processor_;
for (const auto &attr : opinfo.attrs_ptr()) { for (const auto &attr : opinfo.attrs_ptr()) {
@ -122,12 +123,14 @@ class OpInfo {
std::string fusion_type() const { return fusion_type_; } std::string fusion_type() const { return fusion_type_; }
std::string kernel_name() const { return kernel_name_; } std::string kernel_name() const { return kernel_name_; }
OpPattern op_pattern() const { return op_pattern_; } OpPattern op_pattern() const { return op_pattern_; }
bool dynamic_shape() const { return dynamic_shape_; }
std::string processor() const { return processor_; } std::string processor() const { return processor_; }
std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; } std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; } std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; } std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; }
const std::unordered_map<size_t, size_t> &ref_infos() const { return ref_infos_; } const std::unordered_map<size_t, size_t> &ref_infos() const { return ref_infos_; }
void set_dynamic_shape(bool dynamic_shape) { dynamic_shape_ = dynamic_shape; }
void set_op_name(const std::string &op_name) { op_name_ = op_name; } void set_op_name(const std::string &op_name) { op_name_ = op_name; }
void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; } void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; }
void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; } void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; }
@ -149,7 +152,8 @@ class OpInfo {
void ClearOutputs() { (void)outputs_ptr_.clear(); } void ClearOutputs() { (void)outputs_ptr_.clear(); }
bool equals_to(const std::shared_ptr<OpInfo> &other_info) const { bool equals_to(const std::shared_ptr<OpInfo> &other_info) const {
return this->op_name_ == other_info->op_name_ && this->imply_type_ == other_info->imply_type_ && return this->op_name_ == other_info->op_name_ && this->imply_type_ == other_info->imply_type_ &&
this->processor_ == other_info->processor_; this->processor_ == other_info->processor_ && this->op_pattern_ == other_info->op_pattern_ &&
this->dynamic_shape_ == other_info->dynamic_shape_;
} }
private: private:
@ -163,6 +167,7 @@ class OpInfo {
std::string kernel_name_; std::string kernel_name_;
bool partial_flag_ = false; bool partial_flag_ = false;
bool dynamic_format_ = false; bool dynamic_format_ = false;
bool dynamic_shape_ = false;
OpPattern op_pattern_ = kCommonPattern; OpPattern op_pattern_ = kCommonPattern;
std::string processor_; std::string processor_;
std::vector<std::shared_ptr<OpAttr>> attrs_ptr_; std::vector<std::shared_ptr<OpAttr>> attrs_ptr_;

View File

@ -38,6 +38,7 @@ constexpr auto kDynamicFormat = "dynamicFormat";
constexpr auto kFormatAgnostic = "formatAgnostic"; constexpr auto kFormatAgnostic = "formatAgnostic";
constexpr auto kBroadcast = "broadcast"; constexpr auto kBroadcast = "broadcast";
constexpr auto kReduce = "reduce"; constexpr auto kReduce = "reduce";
constexpr auto kDynamicShape = "dynamic_shape";
constexpr auto kDtypeFormat = "dtype_format"; constexpr auto kDtypeFormat = "dtype_format";
constexpr auto kAttr = "attr"; constexpr auto kAttr = "attr";
constexpr auto kIputs = "inputs"; constexpr auto kIputs = "inputs";
@ -111,6 +112,10 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
op_info->set_kernel_name(obj.at(kKernelName)); op_info->set_kernel_name(obj.at(kKernelName));
op_info->set_partial_flag(obj.at(kPartialFlag)); op_info->set_partial_flag(obj.at(kPartialFlag));
if (obj.find(kDynamicShape) != obj.end()) {
op_info->set_dynamic_shape(obj.at(kDynamicShape));
}
if (obj.find(kOpPattern) != obj.end()) { if (obj.find(kOpPattern) != obj.end()) {
std::string op_pattern = obj.at(kOpPattern); std::string op_pattern = obj.at(kOpPattern);
auto find_iter = kOpPatternMap.find(op_pattern); auto find_iter = kOpPatternMap.find(op_pattern);
@ -322,7 +327,7 @@ bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply
return ret; return ret;
} }
std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) { std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType imply_type, bool is_dynamic_shape) {
if (!OpLib::RegOpFromLocalInfo()) { if (!OpLib::RegOpFromLocalInfo()) {
MS_LOG(INFO) << "Warning reg local op info failed."; MS_LOG(INFO) << "Warning reg local op info failed.";
} }
@ -338,16 +343,20 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im
for (auto [iter, end] = op_info_.equal_range(op_name); iter != end; ++iter) { for (auto [iter, end] = op_info_.equal_range(op_name); iter != end; ++iter) {
auto &op_info = iter->second; auto &op_info = iter->second;
MS_EXCEPTION_IF_NULL(op_info); MS_EXCEPTION_IF_NULL(op_info);
if (op_info->imply_type() != imply_type) { if (op_info->imply_type() != imply_type) {
continue; continue;
} }
if (imply_type == kAKG && op_info->processor() != target_processor) { if (imply_type == kAKG && op_info->processor() != target_processor) {
continue; continue;
} }
if (is_dynamic_shape && !op_info->dynamic_shape()) {
continue;
}
return op_info; return op_info;
} }
MS_LOG(INFO) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) MS_LOG(INFO) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
<< ", current op num: " << op_info_.size(); << ", current op num: " << op_info_.size() << " is_dynamic_shape:" << is_dynamic_shape;
return nullptr; return nullptr;
} }

View File

@ -32,7 +32,8 @@ class OpLib {
virtual ~OpLib() = default; virtual ~OpLib() = default;
static bool RegOp(const std::string &json_string, const std::string &impl_path); static bool RegOp(const std::string &json_string, const std::string &impl_path);
static void RegOpInfo(const std::shared_ptr<OpInfo> &opinfo) { op_info_.emplace(opinfo->op_name(), opinfo); } static void RegOpInfo(const std::shared_ptr<OpInfo> &opinfo) { op_info_.emplace(opinfo->op_name(), opinfo); }
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, OpImplyType imply_type); static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, OpImplyType imply_type,
bool is_dynamic_shape = false);
static const std::multimap<std::string, std::shared_ptr<OpInfo>> &GetAllOpsInfo() { return op_info_; } static const std::multimap<std::string, std::shared_ptr<OpInfo>> &GetAllOpsInfo() { return op_info_; }
protected: protected:

View File

@ -21,9 +21,14 @@
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "common/trans.h" #include "common/trans.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "runtime/device/kernel_runtime.h"
#include "runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.h"
using ge::model_runner::MemcpyAsyncTaskInfo; using ge::model_runner::MemcpyAsyncTaskInfo;
using MemcpyAsyncTaskInfoPtr = std::shared_ptr<MemcpyAsyncTaskInfo>; using MemcpyAsyncTaskInfoPtr = std::shared_ptr<MemcpyAsyncTaskInfo>;
using AddressPtrList = std::vector<mindspore::kernel::AddressPtr>;
using mindspore::device::ascend::MemcpyRtsDynamicKernel;
using MemcpyRtsDynamicKernelPtr = std::shared_ptr<MemcpyRtsDynamicKernel>;
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
@ -122,6 +127,32 @@ std::vector<TaskInfoPtr> MemCpyAsyncKernel::GenTask(const std::vector<AddressPtr
MS_EXCEPTION_IF_NULL(task_info_ptr); MS_EXCEPTION_IF_NULL(task_info_ptr);
return {task_info_ptr}; return {task_info_ptr};
} }
device::DynamicKernelPtr MemCpyAsyncKernel::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
if (kernel_inputs.size() != 1) {
MS_LOG(EXCEPTION) << "MemCpyAsync op inputs is not one";
}
if (kernel_outputs.size() != 1) {
MS_LOG(EXCEPTION) << "MemCpyAsync op output is not one";
}
if (kernel_outputs[0]->size < kernel_inputs[0]->size) {
MS_LOG(EXCEPTION) << "Check rtMemcpyAsync destMax < src size";
}
// input x -> memcpy_async -> AllReduce
if (kernel_outputs[0]->size > kernel_inputs[0]->size) {
MS_LOG(WARNING) << "Check rtMemcpyAsync destMax > src size";
}
return std::make_shared<MemcpyRtsDynamicKernel>(stream_ptr, cnode_ptr, kernel_outputs[0]->addr,
kernel_outputs[0]->size, kernel_inputs[0]->addr,
kernel_inputs[0]->size);
}
const std::vector<TypeId> data_type_list{kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, const std::vector<TypeId> data_type_list{kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32,
kNumberTypeInt64, kNumberTypeUInt, kNumberTypeUInt8, kNumberTypeUInt16, kNumberTypeInt64, kNumberTypeUInt, kNumberTypeUInt8, kNumberTypeUInt16,

View File

@ -34,6 +34,7 @@ class MemCpyAsyncKernel : public RtKernel {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override; const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) override; const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override;
private: private:
void GetInputOutputDataType(const AnfNodePtr &anf_node); void GetInputOutputDataType(const AnfNodePtr &anf_node);

View File

@ -21,8 +21,10 @@
#include "framework/ge_runtime/task_info.h" #include "framework/ge_runtime/task_info.h"
#include "runtime/device/ascend/profiling/profiling_utils.h" #include "runtime/device/ascend/profiling/profiling_utils.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.h"
using ProfilerTraceTaskInfo = ge::model_runner::ProfilerTraceTaskInfo; using ProfilerTraceTaskInfo = ge::model_runner::ProfilerTraceTaskInfo;
using mindspore::device::ascend::ProfilingRtsDynamicKernel;
using mindspore::device::ascend::ProfilingUtils; using mindspore::device::ascend::ProfilingUtils;
namespace mindspore { namespace mindspore {
@ -64,5 +66,9 @@ std::vector<TaskInfoPtr> ProfilingKernelMod::GenTask(const std::vector<AddressPt
std::make_shared<ProfilerTraceTaskInfo>(kernel_name_, stream_id, log_id_, notify_, flags_); std::make_shared<ProfilerTraceTaskInfo>(kernel_name_, stream_id, log_id_, notify_, flags_);
return {task_info_ptr}; return {task_info_ptr};
} }
device::DynamicKernelPtr ProfilingKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
return std::make_shared<ProfilingRtsDynamicKernel>(stream_ptr, cnode_ptr, log_id_, notify_, flags_);
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -27,6 +27,7 @@ class ProfilingKernelMod : public RtKernel {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override; const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) override; const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override;
bool Init(const AnfNodePtr &anf_node) override; bool Init(const AnfNodePtr &anf_node) override;
private: private:

View File

@ -29,157 +29,6 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
namespace tbe { namespace tbe {
static std::map<string, string> tbe_func_adapter_map = {
{"softmax", "softmax_v2"},
{"log_softmax", "log_softmax_v2"},
{"apply_momentum", "apply_momentum_d"},
{"apply_ftrl", "apply_ftrl_d"},
{"re_lu6", "relu6"},
{"re_lu6_grad", "relu6_grad"},
{"re_lu", "relu"},
{"reverse_v2", "reverse_v2_d"},
{"re_luv2", "relu_v2"},
{"p_re_lu", "prelu"},
{"p_re_lu_grad", "prelu_grad"},
{"tensor_add", "add"},
{"reduce_mean", "reduce_mean_d"},
{"reduce_max", "reduce_max_d"},
{"reduce_min", "reduce_min_d"},
{"avg_pool_grad", "avg_pool_grad_d"},
{"avg_pool_grad_vm", "avg_pool_grad_d"},
{"conv2d_backprop_filter", "conv2d_backprop_filter_d"},
{"conv2d_backprop_input", "conv2d_backprop_input_d"},
{"depthwise_conv2d_native", "depthwise_conv2d"},
{"depthwise_conv2d_native_backprop_filter", "depthwise_conv2d_backprop_filter_d"},
{"depthwise_conv2d_native_backprop_input", "depthwise_conv2d_backprop_input_d"},
{"scatter_nd", "scatter_nd_d"},
{"tile", "tile_d"},
{"gather_v2", "gather_v2_d"},
{"sparse_gather_v2", "gather_v2_d"},
{"batch_mat_mul", "batch_matmul"},
{"b_n_training_reduce", "bn_training_reduce"},
{"b_n_training_update", "bn_training_update"},
{"b_n_training_update_v2", "bn_training_update_v2"},
{"b_n_training_update_v3", "bn_training_update_v3"},
{"b_n_training_reduce_grad", "bn_training_reduce_grad"},
{"b_n_training_update_grad", "bn_training_update_grad"},
{"b_n_infer", "bn_infer"},
{"b_n_infer_grad", "bn_infer_grad"},
{"b_n_inference", "bninference_d"},
{"n_pu_clear_float_status", "n_p_u_clear_float_status"},
{"n_pu_get_float_status", "n_p_u_get_float_status"},
{"n_pu_alloc_float_status", "n_p_u_alloc_float_status"},
{"dropout_do_mask", "drop_out_do_mask"},
{"strided_slice", "strided_slice_d"},
{"strided_slice_grad", "strided_slice_grad_d"},
{"sparse_apply_ftrl", "sparse_apply_ftrl_d"},
{"sparse_apply_ftrl_v2", "sparse_apply_ftrl_v2_d"},
{"apply_ada_max", "apply_ada_max_d"},
{"apply_adadelta", "apply_adadelta_d"},
{"apply_adagrad", "apply_adagrad_d"},
{"apply_adagrad_v2", "apply_adagradv2_d"},
{"sparse_apply_adagrad", "sparse_apply_adagrad_d"},
{"sparse_apply_adagrad_v2", "sparse_apply_adagrad_v2_d"},
{"apply_proximal_adagrad", "apply_proximal_adagrad_d"},
{"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"},
{"apply_add_sign", "apply_add_sign_d"},
{"apply_power_sign", "apply_power_sign_d"},
{"apply_centered_rms_prop", "apply_centered_rms_prop_d"},
{"transpose", "transpose_d"},
{"fill", "fill_d"},
{"unsorted_segment_sum", "unsorted_segment_sum_d"},
{"unsorted_segment_prod", "unsorted_segment_prod_d"},
{"concat", "concat_d"},
{"slice", "slice_d"},
{"reduce_sum", "reduce_sum_d"},
{"inplace_add", "inplace_add_d"},
{"inplace_sub", "inplace_sub_d"},
{"one_hot", "one_hot_d"},
{"sum", "reduce_sum_d"},
{"lamb_next_mv_with_decay", "lamb_next_m_v_with_decay"},
{"lamb_next_mv", "lamb_next_m_v"},
{"split", "split_d"},
{"split_v", "split_v_d"},
{"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"},
{"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"},
{"pad", "pad_d"},
{"argmax", "arg_max_d"},
{"argmin", "arg_min_d"},
{"space_to_batch", "space_to_batch_d"},
{"batch_to_space", "batch_to_space_d"},
{"space_to_batch_nd", "space_to_batch_nd_d"},
{"batch_to_space_nd", "batch_to_space_nd_d"},
{"resize_bilinear", "resize_bilinear_v2_d"},
{"resize_bilinear_grad", "resize_bilinear_v2_grad"},
{"adam", "apply_adam_d"},
{"r_oi_align", "roi_align"},
{"r_oi_align_grad", "roi_align_grad"},
{"i_ou", "iou"},
{"s_gd", "sgd"},
{"l_rn", "lrn"},
{"l_rn_grad", "lrn_grad"},
{"l_ars_update", "lars_v2_update"},
{"n_ms_with_mask", "nms_with_mask"},
{"square_sum_all", "square_sum_all"},
{"cum_sum", "cumsum_d"},
{"range", "range_d"},
{"lin_space", "lin_space_d"},
{"inv_grad", "inv_grad"},
{"apply_rms_prop", "apply_rms_prop_d"},
{"cum_prod", "cumprod_d"},
{"reduce_all", "reduce_all_d"},
{"reduce_any", "reduce_any_d"},
{"sparse_apply_adagrad", "sparse_apply_adagrad_d"},
{"unsorted_segment_min", "unsorted_segment_min_d"},
{"reduce_prod", "reduce_prod_d"},
{"a_cos", "acos"},
{"a_cos_grad", "acos_grad"},
{"histogram_fixed_width", "histogram_fixed_width_d"},
{"broadcast_to", "broadcast_to_d"},
{"inplace_update", "inplace_update_d"},
{"i_fmr", "ifmr"},
{"matrix_diag", "matrix_diag_d"},
{"matrix_diag_part", "matrix_diag_part_d"},
{"matrix_set_diag", "matrix_set_diag_d"},
{"l_stm_input_grad", "lstm_input_grad"}};
// Normalizes a MindSpore op name into the TBE kernel naming convention:
// first converts CamelCase to lower snake_case (inserting '_' at word
// boundaries), then applies the explicit overrides in tbe_func_adapter_map.
void TbeAdapter::NormalizeFuncName(std::string *func_name) {
  if (func_name == nullptr) {
    MS_LOG(EXCEPTION) << "func_name is null";
  }
  std::string name_tmp;
  // sub_head == true means the current position is already at the start of a
  // new "word" (just after a digit or an inserted '_'), so no '_' is needed.
  bool sub_head = false;
  for (string::iterator iter = func_name->begin(); iter != func_name->end(); ++iter) {
    if (islower(*iter)) {
      sub_head = false;
    }
    if (isdigit(*iter)) {
      sub_head = true;
    }
    if (isupper(*iter) && iter != func_name->begin()) {
      if (!sub_head) {
        // First upper-case after a lower-case run: start a new word here.
        (void)name_tmp.insert(name_tmp.end(), '_');
        sub_head = true;
      } else {
        // Inside an upper-case/digit run: split only before an "Xx" pair
        // (upper-case immediately followed by lower-case).
        string::iterator iter_next = iter + 1;
        if (iter_next != func_name->end()) {
          if (islower(*iter_next)) {
            (void)name_tmp.insert(name_tmp.end(), '_');
          }
        }
      }
    }
    (void)name_tmp.insert(name_tmp.end(), *iter);
  }
  // Lower-case the whole result, then remap names that differ from the
  // mechanical conversion (e.g. "tensor_add" -> "add").
  (void)transform(name_tmp.begin(), name_tmp.end(), name_tmp.begin(), ::tolower);
  *func_name = name_tmp;
  auto iter = tbe_func_adapter_map.find(*func_name);
  if (iter != tbe_func_adapter_map.end()) {
    MS_LOG(INFO) << "Map actual op from me: " << *func_name << " to tbe op: " << iter->second;
    *func_name = iter->second;
  }
}
std::unordered_set<std::string> input_order_adjusted_ops = { std::unordered_set<std::string> input_order_adjusted_ops = {
"Conv2DBackpropInput", "Conv2DBackpropFilter", "LogSoftmaxGrad", "LayerNormGrad", "LayerNormXBackprop", "Conv2DBackpropInput", "Conv2DBackpropFilter", "LogSoftmaxGrad", "LayerNormGrad", "LayerNormXBackprop",
"LayerNormBetaGammaBackprop", "MinimumGrad", "MaximumGrad", "ApplyCenteredRMSProp"}; "LayerNormBetaGammaBackprop", "MinimumGrad", "MaximumGrad", "ApplyCenteredRMSProp"};

View File

@ -35,7 +35,6 @@ class TbeAdapter {
public: public:
TbeAdapter() = default; TbeAdapter() = default;
~TbeAdapter() = default; ~TbeAdapter() = default;
static void NormalizeFuncName(std::string *func_name);
static void InputOrderPass(const std::string &op_name, std::vector<std::vector<nlohmann::json>> const &inputs_list, static void InputOrderPass(const std::string &op_name, std::vector<std::vector<nlohmann::json>> const &inputs_list,
nlohmann::json *inputs_json); nlohmann::json *inputs_json);

View File

@ -0,0 +1,139 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace kernel {
namespace tbe {
// Returns true when any input or output inferred shape of the node carries a
// dynamic-dimension marker.
// Fix: the dimensions are size_t, so the original predicate `dim < 0` was
// tautologically false and the check could never fire. Dynamic dims are
// presumably encoded as -1 cast into size_t (TODO confirm against the shape
// producer), so compare after casting back to a signed type.
bool TbeDynamicShapeUtil::IsDynamicShapeNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto is_dynamic_dim = [](const size_t &dim) { return static_cast<int64_t>(dim) < 0; };
  auto input_num = AnfAlgo ::GetInputTensorNum(cnode);
  for (size_t i = 0; i < input_num; ++i) {
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, i);
    if (std::any_of(input_shape.begin(), input_shape.end(), is_dynamic_dim)) {
      MS_LOG(INFO) << "Node(" << cnode->fullname_with_scope() << ") is dynamic shape node.";
      return true;
    }
  }
  auto output_num = AnfAlgo ::GetOutputTensorNum(cnode);
  for (size_t i = 0; i < output_num; ++i) {
    auto output_shape = AnfAlgo::GetOutputInferShape(cnode, i);
    if (std::any_of(output_shape.begin(), output_shape.end(), is_dynamic_dim)) {
      MS_LOG(INFO) << "Node(" << cnode->fullname_with_scope() << ") is dynamic shape node.";
      return true;
    }
  }
  return false;
}
// AnfNode overload: only CNodes can be dynamic-shape; anything else is static.
bool TbeDynamicShapeUtil::IsDynamicShapeNode(const AnfNodePtr &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (!anf_node->isa<CNode>()) {
    return false;
  }
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  return IsDynamicShapeNode(cnode);
}
// Record the node's dynamic-shape property as an attribute for later passes.
void TbeDynamicShapeUtil::SetDynamicShapeAttr(const CNodePtr &cnode) {
  const bool dynamic = IsDynamicShapeNode(cnode);
  AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(dynamic), cnode);
}
// AnfNode overload: delegate to the CNode version; non-CNodes are never dynamic.
bool TbeDynamicShapeUtil::GetDynamicShapeAttr(const AnfNodePtr &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (!anf_node->isa<CNode>()) {
    return false;
  }
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  return GetDynamicShapeAttr(cnode);
}
bool TbeDynamicShapeUtil::GetDynamicShapeAttr(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto is_dynamic_shape = AnfAlgo::HasNodeAttr(kAttrIsDynamicShape, cnode);
if (!is_dynamic_shape) {
MS_LOG(INFO) << "Node(" << cnode->fullname_with_scope() << ") does not has is_dynamic_shape attribute.";
return false;
}
is_dynamic_shape = AnfAlgo::GetNodeAttr<bool>(cnode, kAttrIsDynamicShape);
return is_dynamic_shape;
}
// AnfNode overload: resolve op info via the CNode version; nullptr for
// non-CNode inputs.
std::shared_ptr<OpInfo> TbeDynamicShapeUtil::FindOp(const std::string &op_name, const AnfNodePtr &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  auto cnode = anf_node->cast<CNodePtr>();
  if (cnode == nullptr) {
    return nullptr;
  }
  return FindOp(op_name, cnode);
}
// Look up the TBE op info matching the node's dynamic-shape property.
std::shared_ptr<OpInfo> TbeDynamicShapeUtil::FindOp(const std::string &op_name, const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  return mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE, GetDynamicShapeAttr(cnode));
}
// Builds the per-dimension (min, max) shape range of the node's input at
// `index`; scalars (empty min/max) yield the single range {1, 1}.
// Fix: the exception message concatenated min and max with no separator,
// producing e.g. "min size: 4max size: 3"; insert ", ". Also reserve the
// result vector up front.
std::vector<std::pair<int, int>> TbeDynamicShapeUtil::GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  auto input_range_min = AnfAlgo::GetInputMinShape(anf_node, index);
  auto input_range_max = AnfAlgo::GetInputMaxShape(anf_node, index);
  if (input_range_min.size() != input_range_max.size()) {
    MS_EXCEPTION(ArgumentError) << "Input range size is not equal, min size: " << input_range_min.size()
                                << ", max size: " << input_range_max.size();
  }
  if (input_range_min.empty() && input_range_max.empty()) {
    return {{1, 1}};
  }
  std::vector<std::pair<int, int>> ret;
  ret.reserve(input_range_min.size());
  for (size_t i = 0; i < input_range_min.size(); ++i) {
    ret.emplace_back(input_range_min[i], input_range_max[i]);
  }
  return ret;
}
// Builds the per-dimension (min, max) shape range of the node's output at
// `index`; scalars (empty min/max) yield the single range {1, 1}.
// Fix: exception message typo "Onput" -> "Output" and missing ", " separator
// between the two sizes. Also reserve the result vector up front.
std::vector<std::pair<int, int>> TbeDynamicShapeUtil::GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  auto output_range_min = AnfAlgo::GetOutputMinShape(anf_node, index);
  auto output_range_max = AnfAlgo::GetOutputMaxShape(anf_node, index);
  if (output_range_min.size() != output_range_max.size()) {
    MS_EXCEPTION(ArgumentError) << "Output range size is not equal, min size: " << output_range_min.size()
                                << ", max size: " << output_range_max.size();
  }
  if (output_range_max.empty() && output_range_min.empty()) {
    return {{1, 1}};
  }
  std::vector<std::pair<int, int>> ret;
  ret.reserve(output_range_min.size());
  for (size_t i = 0; i < output_range_min.size(); ++i) {
    ret.emplace_back(output_range_min[i], output_range_max[i]);
  }
  return ret;
}
} // namespace tbe
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_DYNAMINC_SHAPE_UTIL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_DYNAMINC_SHAPE_UTIL_H
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "mindspore/core/ir/anf.h"
#include "backend/kernel_compiler/oplib/oplib.h"
namespace mindspore {
namespace kernel {
namespace tbe {
// Static helpers for dynamic-shape handling of TBE kernels: detecting
// dynamic-shape nodes, persisting the property as a node attribute, and
// resolving op info / shape ranges accordingly.
class TbeDynamicShapeUtil {
 public:
  TbeDynamicShapeUtil() = default;
  ~TbeDynamicShapeUtil() = default;
  // True when any inferred input/output shape of the node is dynamic.
  static bool IsDynamicShapeNode(const CNodePtr &cnode);
  // AnfNode overload; non-CNodes are never dynamic.
  static bool IsDynamicShapeNode(const AnfNodePtr &anf_node);
  // Stores the detection result on the node as kAttrIsDynamicShape.
  static void SetDynamicShapeAttr(const CNodePtr &cnode);
  // Reads kAttrIsDynamicShape; false when the attribute is absent.
  static bool GetDynamicShapeAttr(const CNodePtr &cnode);
  static bool GetDynamicShapeAttr(const AnfNodePtr &anf_node);
  // Finds the TBE op info matching the node's dynamic-shape property.
  static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, const AnfNodePtr &anf_node);
  static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, const CNodePtr &cnode);
  // Per-dimension (min, max) shape range of an input / output tensor.
  static std::vector<std::pair<int, int>> GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index);
  static std::vector<std::pair<int, int>> GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index);
};
} // namespace tbe
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_DYNAMINC_SHAPE_UTIL_H

View File

@ -23,6 +23,7 @@
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/tbe/tbe_adapter.h" #include "backend/kernel_compiler/tbe/tbe_adapter.h"
#include "backend/kernel_compiler/tbe/tbe_convert_utils.h" #include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "backend/kernel_compiler/tbe/tbe_utils.h" #include "backend/kernel_compiler/tbe/tbe_utils.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "runtime/dev.h" #include "runtime/dev.h"
@ -61,6 +62,7 @@ constexpr auto kJDataType = "data_type";
constexpr auto kJOutputIndex = "output_index"; constexpr auto kJOutputIndex = "output_index";
constexpr auto kJOutputDesc = "output_desc"; constexpr auto kJOutputDesc = "output_desc";
constexpr auto kJInputDesc = "input_desc"; constexpr auto kJInputDesc = "input_desc";
constexpr auto kJRange = "range";
constexpr auto kVTypeInt = "int"; constexpr auto kVTypeInt = "int";
constexpr auto kVTypeStr = "str"; constexpr auto kVTypeStr = "str";
constexpr auto kVTypeBool = "bool"; constexpr auto kVTypeBool = "bool";
@ -89,24 +91,21 @@ constexpr auto kJKwdArgs = "kwds_args";
constexpr auto kJListArgs = "list_args"; constexpr auto kJListArgs = "list_args";
constexpr auto kJSocVersion = "socVersion"; constexpr auto kJSocVersion = "socVersion";
constexpr auto kSOC_VERSION = "SOC_VERSION"; constexpr auto kSOC_VERSION = "SOC_VERSION";
constexpr auto kJIsDynamicShape = "is_dynamic_shape";
bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspore::AnfNode> &anf_node,
nlohmann::json *kernel_json) { nlohmann::json *kernel_json) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(kernel_json); MS_EXCEPTION_IF_NULL(kernel_json);
std::string op_name = AnfAlgo::GetCNodeName(anf_node); std::string op_name = AnfAlgo::GetCNodeName(anf_node);
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE); auto op_info_ptr = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, anf_node);
MS_EXCEPTION_IF_NULL(op_info_ptr); MS_EXCEPTION_IF_NULL(op_info_ptr);
(*kernel_json)[kPlatform] = kPlatTBE; (*kernel_json)[kPlatform] = kPlatTBE;
(*kernel_json)[kGenModel] = kSingle; (*kernel_json)[kGenModel] = kSingle;
(*kernel_json)[kImplPath] = op_info_ptr->impl_path(); (*kernel_json)[kImplPath] = op_info_ptr->impl_path();
nlohmann::json op_info_json; nlohmann::json op_info_json;
if (op_info_ptr->impl_path().empty()) { op_info_json[kJIsDynamicShape] = tbe::TbeDynamicShapeUtil::GetDynamicShapeAttr(anf_node->cast<CNodePtr>());
tbe::TbeAdapter::NormalizeFuncName(&op_name); op_info_json[kJName] = op_info_ptr->kernel_name();
} else {
op_name = op_info_ptr->kernel_name();
}
op_info_json[kJName] = op_name;
// generate inputs json // generate inputs json
nlohmann::json inputs_json; nlohmann::json inputs_json;
if (!GenTbeInputsJson(anf_node, op_info_ptr, &inputs_json)) { if (!GenTbeInputsJson(anf_node, op_info_ptr, &inputs_json)) {
@ -180,6 +179,7 @@ bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr<AnfNode> &anf_
input_desc_json[kJFormat] = format; input_desc_json[kJFormat] = format;
input_desc_json[kJValid] = value; input_desc_json[kJValid] = value;
input_desc_json[kJParamType] = input_ptr->param_type(); input_desc_json[kJParamType] = input_ptr->param_type();
input_desc_json[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index);
input_list->emplace_back(input_desc_json); input_list->emplace_back(input_desc_json);
} }
return true; return true;
@ -359,8 +359,13 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_nod
for (size_t i = 0; i < output_obj_num; i++) { for (size_t i = 0; i < output_obj_num; i++) {
auto dtype = GetDeviceOutputType(anf_node, *output_idx); auto dtype = GetDeviceOutputType(anf_node, *output_idx);
auto format = GetDeviceOutputFormat(anf_node, *output_idx); auto format = GetDeviceOutputFormat(anf_node, *output_idx);
auto shape = GetDeviceOutputShape(anf_node, *output_idx);
std::vector<size_t> ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx); std::vector<int64_t> shape;
AnfAlgo::GetRealDynamicShape(GetDeviceOutputShape(anf_node, *output_idx), NOT_NULL(&shape));
std::vector<int64_t> ori_shape;
AnfAlgo::GetRealDynamicShape(AnfAlgo::GetOutputInferShape(anf_node, *output_idx), NOT_NULL(&ori_shape));
// std::vector<size_t> ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx);
if (ori_shape.empty()) { if (ori_shape.empty()) {
ori_shape.emplace_back(1); ori_shape.emplace_back(1);
} }
@ -373,6 +378,7 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_nod
output_obj[kJName] = output_ptr->name(); output_obj[kJName] = output_ptr->name();
output_obj[kJValid] = true; output_obj[kJValid] = true;
output_obj[kJParamType] = output_ptr->param_type(); output_obj[kJParamType] = output_ptr->param_type();
output_obj[kJRange] = tbe::TbeDynamicShapeUtil::GetOutputDynamicRange(anf_node, *output_idx);
output_list->emplace_back(output_obj); output_list->emplace_back(output_obj);
(*output_idx)++; (*output_idx)++;
} }
@ -575,48 +581,76 @@ std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_no
return format; return format;
} }
// Accumulates the byte size of every valid input tensor described in the
// kernel json; dynamic (-1) dimensions are replaced with the node's max shape.
// Improvement: cache input_json[i][m] in a const reference instead of
// re-walking the json tree on every field/dimension access.
// NOTE(review): `i` (the json group index) is passed to GetInputMaxShape as
// the input index — assumes groups map 1:1 to real inputs; TODO confirm.
void GetInputSizeList(const nlohmann::json &input_json, std::vector<size_t> *input_size_list,
                      const AnfNodePtr &anf_node) {
  for (size_t i = 0; i < input_json.size(); i++) {
    for (size_t m = 0; m < input_json[i].size(); m++) {
      const auto &input_desc = input_json[i][m];
      if (input_desc[kJValid] == false) {
        std::string input_name = input_desc[kJName];
        MS_LOG(INFO) << "Input name:" << input_name << "is optional, valid is false.";
        continue;
      }
      size_t size_i = 1;
      const auto &shape = input_desc[kJShape];
      for (size_t j = 0; j < shape.size(); ++j) {
        if (shape[j] == -1) {
          // Size a dynamic dimension by the node's maximum shape.
          auto input_max_shape = AnfAlgo::GetInputMaxShape(anf_node, i);
          if (j >= input_max_shape.size()) {
            MS_LOG(EXCEPTION) << "Invalid Dynamic Shape Max Shape";
          }
          MS_LOG(INFO) << "Change -1 Shape to Max Shape:" << input_max_shape[j];
          size_i *= input_max_shape[j];
          continue;
        }
        size_i *= static_cast<size_t>(shape[j]);
      }
      std::string dtype = input_desc[kJDtype];
      size_t nbyte = tbe::GetDtypeNbyte(dtype);
      size_i *= nbyte;
      input_size_list->push_back(size_i);
    }
  }
}
// Accumulates the byte size of every valid output tensor described in the
// kernel json; dynamic (-1) dimensions are replaced with the node's max shape.
// Improvement: cache output_json[i][m] in a const reference instead of
// re-walking the json tree on every field/dimension access.
// NOTE(review): `i` (the json group index) is passed to GetOutputMaxShape as
// the output index — assumes groups map 1:1 to real outputs; TODO confirm.
void GetOutputSizeList(const nlohmann::json &output_json, std::vector<size_t> *output_size_list,
                       const AnfNodePtr &anf_node) {
  for (size_t i = 0; i < output_json.size(); i++) {
    for (size_t m = 0; m < output_json[i].size(); m++) {
      const auto &output_desc = output_json[i][m];
      if (output_desc[kJValid] == false) {
        std::string output_name = output_desc[kJName];
        MS_LOG(INFO) << "Output name:" << output_name << " is optional, valid is false.";
        continue;
      }
      size_t size_i = 1;
      const auto &shape = output_desc[kJShape];
      for (size_t j = 0; j < shape.size(); ++j) {
        if (shape[j] == -1) {
          // Size a dynamic dimension by the node's maximum shape.
          auto output_max_shape = AnfAlgo::GetOutputMaxShape(anf_node, i);
          if (j >= output_max_shape.size()) {
            MS_LOG(EXCEPTION) << "Invalid Dynamic Shape Max Shape";
          }
          MS_LOG(INFO) << "Change -1 Shape to Max Shape:" << output_max_shape[j];
          size_i *= output_max_shape[j];
          continue;
        }
        size_i *= static_cast<size_t>(shape[j]);
      }
      std::string dtype = output_desc[kJDtype];
      size_t nbyte = tbe::GetDtypeNbyte(dtype);
      size_i *= nbyte;
      output_size_list->push_back(size_i);
    }
  }
}
bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list, bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,
std::vector<size_t> *output_size_list) { std::vector<size_t> *output_size_list, const AnfNodePtr &anf_node) {
if (input_size_list == nullptr || output_size_list == nullptr) { if (input_size_list == nullptr || output_size_list == nullptr) {
MS_LOG(ERROR) << "Input size or output size is nullptr"; MS_LOG(ERROR) << "Input size or output size is nullptr";
return false; return false;
} }
input_size_list->clear(); input_size_list->clear();
output_size_list->clear(); output_size_list->clear();
for (size_t i = 0; i < kernel_json[kJOpInfo][kJInputs].size(); i++) { GetInputSizeList(kernel_json[kJOpInfo][kJInputs], input_size_list, anf_node);
for (size_t m = 0; m < kernel_json[kJOpInfo][kJInputs][i].size(); m++) { GetOutputSizeList(kernel_json[kJOpInfo][kJOutputs], output_size_list, anf_node);
size_t size_i = 1;
if (kernel_json[kJOpInfo][kJInputs][i][m][kJValid] == false) {
std::string input_name = kernel_json[kJOpInfo][kJInputs][i][m][kJName];
MS_LOG(INFO) << "Input name:" << input_name << "is optional, valid is false.";
continue;
}
for (const auto &j : kernel_json[kJOpInfo][kJInputs][i][m][kJShape]) {
size_i *= static_cast<size_t>(j);
}
std::string dtype = kernel_json[kJOpInfo][kJInputs][i][m][kJDtype];
size_t nbyte = tbe::GetDtypeNbyte(dtype);
size_i *= nbyte;
input_size_list->push_back(size_i);
}
}
for (size_t i = 0; i < kernel_json[kJOpInfo][kJOutputs].size(); i++) {
for (size_t m = 0; m < kernel_json[kJOpInfo][kJOutputs][i].size(); m++) {
size_t size_i = 1;
if (kernel_json[kJOpInfo][kJOutputs][i][m][kJValid] == false) {
std::string output_name = kernel_json[kJOpInfo][kJOutputs][i][m][kJName];
MS_LOG(INFO) << "Output name:" << output_name << " is optional, valid is false.";
continue;
}
for (const auto &j : kernel_json[kJOpInfo][kJOutputs][i][m][kJShape]) {
size_i *= static_cast<size_t>(j);
}
std::string dtype = kernel_json[kJOpInfo][kJOutputs][i][m][kJDtype];
size_t nbyte = tbe::GetDtypeNbyte(dtype);
size_i *= nbyte;
output_size_list->push_back(size_i);
}
}
return true; return true;
} }
@ -678,17 +712,18 @@ void TbeKernelBuild::GenFusionComputeCommonJson(const mindspore::CNodePtr &cnode
MS_EXCEPTION_IF_NULL(fusion_kernel_name); MS_EXCEPTION_IF_NULL(fusion_kernel_name);
// gen others // gen others
auto origin_type = AnfAlgo::GetCNodeName(cnode); auto origin_type = AnfAlgo::GetCNodeName(cnode);
auto op_info_ptr = tbe::TbeDynamicShapeUtil::FindOp(origin_type, cnode);
// replace special op type for buffer fusion op // replace special op type for buffer fusion op
auto type = GetRealOpType(origin_type); auto type = GetRealOpType(origin_type);
(*compute_op_str)[kJtype] = type; (*compute_op_str)[kJtype] = type;
tbe::TbeAdapter::NormalizeFuncName(&type); auto kernel_name = op_info_ptr->kernel_name();
(*compute_op_str)[kJFuncName] = type; (*compute_op_str)[kJFuncName] = kernel_name;
(*compute_op_str)[kJModuleName] = std::string("impl.") + type; (*compute_op_str)[kJModuleName] = std::string("impl.") + type;
(*compute_op_str)[kJName] = cnode->fullname_with_scope(); (*compute_op_str)[kJName] = cnode->fullname_with_scope();
(*compute_op_str)[kJPattern] = GetNodeFusionType(cnode); (*compute_op_str)[kJPattern] = GetNodeFusionType(cnode);
(*compute_op_str)[kJPyModulePath] = "/usr/local/Ascend/opp/op_impl/build_in/ai_core/tbe"; (*compute_op_str)[kJPyModulePath] = "/usr/local/Ascend/opp/op_impl/build_in/ai_core/tbe";
(void)(*fusion_kernel_name).append("_"); (void)(*fusion_kernel_name).append("_");
(void)(*fusion_kernel_name).append(type); (void)(*fusion_kernel_name).append(kernel_name);
} }
void TbeKernelBuild::GenFusionComputePreBuildJson(const mindspore::CNodePtr &cnode, nlohmann::json *compute_op_str) { void TbeKernelBuild::GenFusionComputePreBuildJson(const mindspore::CNodePtr &cnode, nlohmann::json *compute_op_str) {
@ -952,7 +987,7 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i
} }
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
auto node_name = AnfAlgo::GetCNodeName(cnode); auto node_name = AnfAlgo::GetCNodeName(cnode);
auto op_info = OpLib::FindOp(node_name, kTBE); auto op_info = tbe::TbeDynamicShapeUtil::FindOp(node_name, cnode);
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
if (op_info->inputs_ptr().size() < (cnode->inputs().size() - 1)) { if (op_info->inputs_ptr().size() < (cnode->inputs().size() - 1)) {
MS_EXCEPTION(ArgumentError) << "op info error, node name:" << cnode->fullname_with_scope(); MS_EXCEPTION(ArgumentError) << "op info error, node name:" << cnode->fullname_with_scope();

View File

@ -38,7 +38,7 @@ class TbeKernelBuild {
public: public:
static bool GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list, static bool GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,
std::vector<size_t> *output_size_list); std::vector<size_t> *output_size_list, const AnfNodePtr &anf_node);
// Ub Fuison // Ub Fuison
static bool GenFusionScopeJson(const std::vector<AnfNodePtr> &input_nodes, static bool GenFusionScopeJson(const std::vector<AnfNodePtr> &input_nodes,
const std::vector<AnfNodePtr> &compute_nodes, nlohmann::json *fusion_json, const std::vector<AnfNodePtr> &compute_nodes, nlohmann::json *fusion_json,

View File

@ -19,11 +19,14 @@
#include "runtime/rt.h" #include "runtime/rt.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "graphengine/inc/framework/ge_runtime/task_info.h" #include "graphengine/inc/framework/ge_runtime/task_info.h"
#include "runtime/device/ascend/executor/ai_core_dynamic_kernel.h"
#include "runtime/device/kernel_runtime.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
using TbeTaskInfoPtr = std::shared_ptr<ge::model_runner::TbeTaskInfo>; using TbeTaskInfoPtr = std::shared_ptr<ge::model_runner::TbeTaskInfo>;
using tbe::KernelManager; using tbe::KernelManager;
using AddressPtrList = std::vector<mindspore::kernel::AddressPtr>;
bool TbeKernelMod::Launch(const std::vector<mindspore::kernel::AddressPtr> &inputs, bool TbeKernelMod::Launch(const std::vector<mindspore::kernel::AddressPtr> &inputs,
const std::vector<mindspore::kernel::AddressPtr> &workspace, const std::vector<mindspore::kernel::AddressPtr> &workspace,
const std::vector<mindspore::kernel::AddressPtr> &outputs, void *stream_ptr) { const std::vector<mindspore::kernel::AddressPtr> &outputs, void *stream_ptr) {
@ -105,6 +108,49 @@ std::vector<TaskInfoPtr> TbeKernelMod::GenTask(const std::vector<AddressPtr> &in
return {task_info_ptr}; return {task_info_ptr};
} }
// Build the dynamic-shape executor for this kernel: resolve the stub
// function, flatten the launch addresses, and allocate tiling memory.
device::DynamicKernelPtr TbeKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
  // Collect input/workspace/output launch addresses for the node.
  AddressPtrList kernel_inputs;
  AddressPtrList kernel_workspaces;
  AddressPtrList kernel_outputs;
  device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
  // The tiling-parameter size comes from the compiled kernel's json info.
  auto kernel_json_info = kernel_pack_->kernel_json_info();
  auto op_para_size = kernel_json_info.op_para_size;
  // Resolve the stub function; block_dim defaults to 1.
  uint32_t block_dim = 1;
  auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim);
  if (func_stub == 0) {
    MS_LOG(EXCEPTION) << "GenFuncStub failed.";
  }
  const void *stub_func_ptr = reinterpret_cast<void *>(func_stub);
  // Flatten device addresses into the runtime argument list, in the order
  // inputs, outputs, workspaces.
  std::vector<void *> runtime_args;
  auto append_addrs = [&runtime_args](const AddressPtrList &addr_list) {
    for (const auto &item : addr_list) {
      runtime_args.push_back(item->addr);
    }
  };
  append_addrs(kernel_inputs);
  append_addrs(kernel_outputs);
  append_addrs(kernel_workspaces);
  // Dynamic-shape kernels may need extra device memory for tiling data.
  void *tiling_data_ptr = nullptr;
  if (op_para_size > 0) {
    if (rtMalloc(&tiling_data_ptr, op_para_size, RT_MEMORY_HBM) != RT_ERROR_NONE) {
      MS_LOG(EXCEPTION) << "rtMalloc tiling data failed";
    }
    runtime_args.push_back(tiling_data_ptr);
  }
  return std::make_shared<device::ascend::AiCoreDynamicKernel>(stub_func_ptr, block_dim, tiling_data_ptr,
                                                               op_para_size, stream_ptr, cnode_ptr, runtime_args);
}
vector<size_t> TbeKernelMod::GenParameters() { vector<size_t> TbeKernelMod::GenParameters() {
auto kernel_json_info = kernel_pack_->kernel_json_info(); auto kernel_json_info = kernel_pack_->kernel_json_info();
return kernel_json_info.parameters; return kernel_json_info.parameters;

View File

@ -42,6 +42,7 @@ class TbeKernelMod : public AscendKernelMod {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override; const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces, std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) override; const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override;
std::vector<size_t> GenParameters() override; std::vector<size_t> GenParameters() override;
private: private:

View File

@ -15,13 +15,11 @@
*/ */
#include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h" #include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h"
#include <memory> #include <memory>
#include <set> #include <set>
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include <string> #include <string>
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "backend/kernel_compiler/tbe/tbe_adapter.h" #include "backend/kernel_compiler/tbe/tbe_adapter.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_build.h" #include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
@ -29,6 +27,7 @@
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/tbe/tbe_convert_utils.h" #include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
#include "backend/kernel_compiler/tbe/tbe_utils.h" #include "backend/kernel_compiler/tbe/tbe_utils.h"
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
@ -52,15 +51,18 @@ bool TbeOpParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
// get size // get size
std::vector<size_t> input_size_list; std::vector<size_t> input_size_list;
std::vector<size_t> output_size_list; std::vector<size_t> output_size_list;
(void)TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list); (void)TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list, anf_node);
// search cache // search cache
const std::string &json_name = creator.json_name(); const std::string &json_name = creator.json_name();
if (build_manger->SearchInCache(json_name, processor, input_size_list, output_size_list, anf_node.get())) { auto IsDynamicShape = tbe::TbeDynamicShapeUtil::GetDynamicShapeAttr(anf_node);
MS_LOG(INFO) << "Use cached kernel, kernel json name:." << json_name; if (build_manger->SearchInCache(json_name, processor, input_size_list, output_size_list, anf_node.get()) &&
!IsDynamicShape) {
MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " Use cached kernel, kernel json name:."
<< json_name;
continue; continue;
} }
// same op not need build, but need wait build finish to set kernel mode // same op not need build, but need wait build finish to set kernel mode
if (processed_kernel.find(json_name) != processed_kernel.end()) { if (processed_kernel.find(json_name) != processed_kernel.end() && !IsDynamicShape) {
build_manger->SaveSameOpInfo(anf_node, json_name, input_size_list, output_size_list); build_manger->SaveSameOpInfo(anf_node, json_name, input_size_list, output_size_list);
continue; continue;
} }
@ -72,8 +74,8 @@ bool TbeOpParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
while (!build_manger->IsAllTaskFinish()) { while (!build_manger->IsAllTaskFinish()) {
int task_id = -1; int task_id = -1;
std::string task_result; std::string task_result;
std::string pre_build_result; std::string build_result;
auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result); auto ret = build_manger->WaitOne(&task_id, &task_result, &build_result);
if (!ret) { if (!ret) {
MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id; MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id;
} }
@ -81,7 +83,7 @@ bool TbeOpParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
if (task_result != "Success") { if (task_result != "Success") {
MS_EXCEPTION(ArgumentError) << "task compile Failed, task id:" << task_id << ", cause:" << task_result; MS_EXCEPTION(ArgumentError) << "task compile Failed, task id:" << task_id << ", cause:" << task_result;
} }
(void)build_manger->TaskFinishProcess(task_id); (void)build_manger->TaskFinishProcess(task_id, build_result);
} }
return build_manger->GenSameOpKernelMod(); return build_manger->GenSameOpKernelMod();
} }
@ -93,7 +95,7 @@ void ParallelBuildManager::SaveTaskInfo(int32_t task_id, const mindspore::AnfNod
const std::vector<size_t> &output_size_list, int32_t scope_id) { const std::vector<size_t> &output_size_list, int32_t scope_id) {
MS_LOG(INFO) << "SaveTaskInfo, task id: " << task_id; MS_LOG(INFO) << "SaveTaskInfo, task id: " << task_id;
struct KernelBuildTaskInfo task_info; struct KernelBuildTaskInfo task_info;
task_info.node = anf_node.get(); task_info.node = anf_node;
task_info.json_name = json_name; task_info.json_name = json_name;
if (anf_node == nullptr) { if (anf_node == nullptr) {
task_info.processor = tbe::kProcessorAiCore; task_info.processor = tbe::kProcessorAiCore;
@ -111,7 +113,38 @@ bool ParallelBuildManager::IsAllTaskFinish() const {
return task_map_.empty(); return task_map_.empty();
} }
std::pair<int32_t, KernelModPtr> ParallelBuildManager::TaskFinishProcess(int32_t task_id, bool set_kernel_mod) { void ParallelBuildManager::PreTaskFinishProcess(int32_t task_id, const std::string &pre_build_result) {
auto task_iter = pre_task_map_.find(task_id);
if (task_iter == pre_task_map_.end()) {
MS_EXCEPTION(ArgumentError) << "can find pre task_id:" << task_id;
}
auto node = task_iter->second;
auto builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
std::string start_flag = "fusion_pattern_start";
std::string end_flag = "fusion_pattern_end";
int start = pre_build_result.find(start_flag);
int end = pre_build_result.find(end_flag);
if (start != -1 && end != -1 && end >= start) {
std::string result = pre_build_result.substr(start + start_flag.size(), end - start - start_flag.size());
if (result.empty()) {
(void)pre_task_map_.erase(task_iter);
return;
}
transform(result.begin(), result.end(), result.begin(), ::toupper);
AnfAlgo::SetNodeAttr(kAttrFusionType, MakeValue(result), node);
FusionType fusion_type = tbe::GetFusionType(result);
builder->SetFusionType(fusion_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
}
(void)pre_task_map_.erase(task_iter);
}
std::pair<int32_t, KernelModPtr> ParallelBuildManager::TaskFinishProcess(int32_t task_id, const std::string &build_ret,
bool set_kernel_mod) {
auto compile_info = ProcessBuildRetStr(build_ret);
MS_LOG(DEBUG) << "Tbe build ret:" << compile_info;
auto task_iter = task_map_.find(task_id); auto task_iter = task_map_.find(task_id);
if (task_iter == task_map_.end()) { if (task_iter == task_map_.end()) {
MS_EXCEPTION(ArgumentError) << "can find task_id:" << task_id; MS_EXCEPTION(ArgumentError) << "can find task_id:" << task_id;
@ -133,7 +166,9 @@ std::pair<int32_t, KernelModPtr> ParallelBuildManager::TaskFinishProcess(int32_t
task_iter->second.output_size_list, kernel_pack); task_iter->second.output_size_list, kernel_pack);
MS_EXCEPTION_IF_NULL(kernel_mod); MS_EXCEPTION_IF_NULL(kernel_mod);
if (set_kernel_mod) { if (set_kernel_mod) {
AnfAlgo::SetKernelMod(kernel_mod, task_iter->second.node); AnfAlgo::SetKernelMod(kernel_mod, task_iter->second.node.get());
AnfAlgo::SetNodeAttr(kAttrCompileInfo, MakeValue(compile_info), task_iter->second.node);
MS_LOG(DEBUG) << "Set Node Attr compile_info:" << compile_info;
} }
auto ret = std::make_pair(task_iter->second.scope_id, kernel_mod); auto ret = std::make_pair(task_iter->second.scope_id, kernel_mod);
(void)task_map_.erase(task_iter); (void)task_map_.erase(task_iter);
@ -145,7 +180,7 @@ void ParallelBuildManager::SaveSameOpInfo(const mindspore::AnfNodePtr &anf_node,
const std::vector<size_t> &input_size_list, const std::vector<size_t> &input_size_list,
const std::vector<size_t> &output_size_list) { const std::vector<size_t> &output_size_list) {
struct KernelBuildTaskInfo task_info; struct KernelBuildTaskInfo task_info;
task_info.node = anf_node.get(); task_info.node = anf_node;
task_info.json_name = json_name; task_info.json_name = json_name;
task_info.processor = tbe::GetProcessor(anf_node); task_info.processor = tbe::GetProcessor(anf_node);
task_info.input_size_list.assign(input_size_list.begin(), input_size_list.end()); task_info.input_size_list.assign(input_size_list.begin(), input_size_list.end());
@ -156,7 +191,7 @@ void ParallelBuildManager::SaveSameOpInfo(const mindspore::AnfNodePtr &anf_node,
bool ParallelBuildManager::GenSameOpKernelMod() const { bool ParallelBuildManager::GenSameOpKernelMod() const {
for (const auto &task_info : same_op_list_) { for (const auto &task_info : same_op_list_) {
bool ret = SearchInCache(task_info.json_name, task_info.processor, task_info.input_size_list, bool ret = SearchInCache(task_info.json_name, task_info.processor, task_info.input_size_list,
task_info.output_size_list, task_info.node); task_info.output_size_list, task_info.node.get());
if (!ret) { if (!ret) {
MS_LOG(INFO) << "can't find " << task_info.json_name << " in cache."; MS_LOG(INFO) << "can't find " << task_info.json_name << " in cache.";
return false; return false;
@ -212,5 +247,20 @@ void ParallelBuildManager::ResetTaskInfo() {
same_op_list_.clear(); same_op_list_.clear();
AscendKernelBuildClient::Instance().TbeReset(); AscendKernelBuildClient::Instance().TbeReset();
} }
// Extract the fusion-pattern payload from a backend build-result string.
// The payload is delimited by "fusion_pattern_start"/"fusion_pattern_end".
// Returns the text between the markers, or "" when the markers are missing,
// out of order, or enclose nothing.
std::string ParallelBuildManager::ProcessBuildRetStr(const std::string &build_result) {
  const std::string start_flag = "fusion_pattern_start";
  const std::string end_flag = "fusion_pattern_end";
  // Fix: std::string::find returns std::string::npos (size_type); the original
  // stored it in an int, truncating on large inputs. Compare against npos with
  // the proper unsigned type, and require `end` to lie past the entire start
  // marker so the substr length cannot wrap around to a huge unsigned value.
  auto start = build_result.find(start_flag);
  auto end = build_result.find(end_flag);
  if (start != std::string::npos && end != std::string::npos && end >= start + start_flag.size()) {
    std::string result = build_result.substr(start + start_flag.size(), end - start - start_flag.size());
    if (!result.empty()) {
      return result;
    }
  }
  return "";
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -31,7 +31,7 @@ namespace kernel {
bool TbeOpParallelBuild(const std::vector<AnfNodePtr> &anf_nodes); bool TbeOpParallelBuild(const std::vector<AnfNodePtr> &anf_nodes);
struct KernelBuildTaskInfo { struct KernelBuildTaskInfo {
AnfNode *node; AnfNodePtr node;
std::string processor; std::string processor;
std::string json_name; std::string json_name;
std::vector<size_t> input_size_list; std::vector<size_t> input_size_list;
@ -53,16 +53,21 @@ class ParallelBuildManager {
const std::vector<size_t> &input_size_list, const std::vector<size_t> &output_size_list, const std::vector<size_t> &input_size_list, const std::vector<size_t> &output_size_list,
AnfNode *node) const; AnfNode *node) const;
bool IsAllTaskFinish() const; bool IsAllTaskFinish() const;
std::pair<int32_t, KernelModPtr> TaskFinishProcess(int32_t task_id, bool set_kernel_mod = true); void PreTaskFinishProcess(int32_t task_id, const std::string &pre_build_result);
std::pair<int32_t, KernelModPtr> TaskFinishProcess(int32_t task_id, const std::string &build_ret,
bool set_kernel_mod = true);
KernelModPtr GenKernelMod(const string &json_name, const string &processor, KernelModPtr GenKernelMod(const string &json_name, const string &processor,
const std::vector<size_t> &input_size_list, const std::vector<size_t> &output_size_list, const std::vector<size_t> &input_size_list, const std::vector<size_t> &output_size_list,
const KernelPackPtr &kernel_pack) const; const KernelPackPtr &kernel_pack) const;
// Interactive with real backend, who could be implemented by Python. // Interactive with real backend, who could be implemented by Python.
int StartCompileOp(const nlohmann::json &kernel_json); static int StartCompileOp(const nlohmann::json &kernel_json);
bool WaitOne(int *task_id, std::string *task_result, std::string *pre_build_result); static bool WaitOne(int *task_id, std::string *task_result, std::string *build_result);
void ResetTaskInfo(); void ResetTaskInfo();
private:
std::string ProcessBuildRetStr(const std::string &build_result);
private: private:
std::map<int32_t, AnfNodePtr> pre_task_map_; std::map<int32_t, AnfNodePtr> pre_task_map_;
std::map<int32_t, KernelBuildTaskInfo> task_map_; std::map<int32_t, KernelBuildTaskInfo> task_map_;

View File

@ -30,6 +30,7 @@
#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h" #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" #include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.h" #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.h"
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "backend/session/kernel_build_client.h" #include "backend/session/kernel_build_client.h"
namespace mindspore { namespace mindspore {
@ -54,7 +55,8 @@ void TbeKernelSelect::TbeMetadataInfoEx() {
MS_EXCEPTION_IF_NULL(cnode_ptr_); MS_EXCEPTION_IF_NULL(cnode_ptr_);
MS_EXCEPTION_IF_NULL(kernel_info_list_); MS_EXCEPTION_IF_NULL(kernel_info_list_);
node_name_ = AnfAlgo::GetCNodeName(cnode_ptr_); node_name_ = AnfAlgo::GetCNodeName(cnode_ptr_);
auto op_info_ptr = OpLib::FindOp(node_name_, kTBE);
auto op_info_ptr = tbe::TbeDynamicShapeUtil::FindOp(node_name_, cnode_ptr_);
if (!op_info_ptr) { if (!op_info_ptr) {
MS_LOG(INFO) << "Warning: Cann't find tbe core opinfo, node type: " << node_name_; MS_LOG(INFO) << "Warning: Cann't find tbe core opinfo, node type: " << node_name_;
return; return;
@ -81,6 +83,7 @@ void TbeKernelSelect::TbeMetadataInfoEx() {
} }
// check support // check support
FilterInVaildKernelInfo(); FilterInVaildKernelInfo();
MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select.";
} }
void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) { void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {

View File

@ -23,6 +23,7 @@
#include "backend/kernel_compiler/kernel_query.h" #include "backend/kernel_compiler/kernel_query.h"
#include "backend/kernel_compiler/oplib/oplib.h" #include "backend/kernel_compiler/oplib/oplib.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@ -62,7 +63,7 @@ class KernelQuery {
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
return false; return false;
} }
auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(node), kernel::kTBE); auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(node), node);
if (op_info != nullptr) { if (op_info != nullptr) {
return op_info->is_ref(); return op_info->is_ref();
} }
@ -75,8 +76,8 @@ class OpFinder {
public: public:
OpFinder() = default; OpFinder() = default;
virtual ~OpFinder() = default; virtual ~OpFinder() = default;
virtual int GetOpRegisteredOutputNum(const std::string &op_name) { virtual int GetOpRegisteredOutputNum(const std::string &op_name, const CNodePtr &cnode) {
auto op_info = kernel::OpLib::FindOp(op_name, kernel::kTBE); auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
if (op_info == nullptr) { if (op_info == nullptr) {
return -1; return -1;
} }

View File

@ -46,6 +46,9 @@ const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph
const EquivPtr &) const { const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::IsDynamicShape(node)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);

View File

@ -36,7 +36,7 @@ session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) {
auto cnode = cur_node->cast<CNodePtr>(); auto cnode = cur_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
std::string op_name = AnfAlgo::GetCNodeName(cnode); std::string op_name = AnfAlgo::GetCNodeName(cnode);
auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
// deal ref op // deal ref op
if (op_info != nullptr && op_info->is_ref()) { if (op_info != nullptr && op_info->is_ref()) {
auto ref_infos = op_info->ref_infos(); auto ref_infos = op_info->ref_infos();
@ -223,7 +223,7 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A
DealBroadCastAsRef(graph, cnode); DealBroadCastAsRef(graph, cnode);
auto op_name = AnfAlgo::GetCNodeName(cnode); auto op_name = AnfAlgo::GetCNodeName(cnode);
auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
if (op_info == nullptr || !op_info->is_ref()) { if (op_info == nullptr || !op_info->is_ref()) {
return nullptr; return nullptr;
} }

View File

@ -65,6 +65,9 @@ const AnfNodePtr ConcatFission::Process(const FuncGraphPtr &func_graph, const An
const EquivPtr &) const { const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::IsDynamicShape(node)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
// The real input begins with index 1. // The real input begins with index 1.

View File

@ -86,6 +86,9 @@ const AnfNodePtr LayerNormGradSplit::Process(const FuncGraphPtr &graph, const An
const EquivPtr &) const { const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::IsDynamicShape(node)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
if (cnode->inputs().size() != kLayerNormGradInputNum) { if (cnode->inputs().size() != kLayerNormGradInputNum) {
return nullptr; return nullptr;

View File

@ -72,6 +72,9 @@ const BaseRef PackFission::DefinePattern() const {
const AnfNodePtr PackFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { const AnfNodePtr PackFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::IsDynamicShape(node)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
// The real input begins with index 1. // The real input begins with index 1.

View File

@ -105,6 +105,9 @@ const AnfNodePtr ReduceMinFission::Process(const FuncGraphPtr &graph, const AnfN
if (graph == nullptr || node == nullptr) { if (graph == nullptr || node == nullptr) {
return nullptr; return nullptr;
} }
if (AnfAlgo::IsDynamicShape(node)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
CheckCNodeInputSize(cnode, 2); CheckCNodeInputSize(cnode, 2);

View File

@ -174,6 +174,9 @@ const BaseRef SplitFission::DefinePattern() const {
const AnfNodePtr SplitFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { const AnfNodePtr SplitFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::IsDynamicShape(node)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
// Check output num // Check output num

View File

@ -127,6 +127,9 @@ const BaseRef TopKSplit::DefinePattern() const {
const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::IsDynamicShape(node)) {
return nullptr;
}
auto kernel_graph = func_graph->cast<KernelGraphPtr>(); auto kernel_graph = func_graph->cast<KernelGraphPtr>();
// set value node as topk's input // set value node as topk's input
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();

View File

@ -86,7 +86,7 @@ const AnfNodePtr AddInputToOutput::Process(const FuncGraphPtr &func_graph, const
if (!InputToOutputRegistry::Instance().GetRegisterByOpName(op_name, &reg)) { if (!InputToOutputRegistry::Instance().GetRegisterByOpName(op_name, &reg)) {
return nullptr; return nullptr;
} }
int output_num = op_finder_->GetOpRegisteredOutputNum(op_name); int output_num = op_finder_->GetOpRegisteredOutputNum(op_name, cnode);
// No need add output when it is not a tbe op. // No need add output when it is not a tbe op.
if (output_num == -1) { if (output_num == -1) {
return nullptr; return nullptr;

View File

@ -84,6 +84,9 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf
MS_LOG(INFO) << "mul's second input is not addn"; MS_LOG(INFO) << "mul's second input is not addn";
return true; return true;
} }
if (AnfAlgo::IsDynamicShape(addn)) {
return true;
}
std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(addn, 0); std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(addn, 0);
if (shape.size() != 2 || !(shape[1] == 1024 || shape[1] == 768)) { if (shape.size() != 2 || !(shape[1] == 1024 || shape[1] == 768)) {
MS_LOG(INFO) << "Addn's infer shape is not equal [x,1024] or [x,768]"; MS_LOG(INFO) << "Addn's infer shape is not equal [x,1024] or [x,768]";

View File

@ -53,6 +53,9 @@ CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
// ReluV2's 2rd output is mask whose data type is uint8 // ReluV2's 2rd output is mask whose data type is uint8
TypeId mask_dtype = kNumberTypeUInt8; TypeId mask_dtype = kNumberTypeUInt8;
if (AnfAlgo::IsDynamicShape(relu)) {
return nullptr;
}
std::vector<size_t> mask_shape = AnfAlgo::GetOutputInferShape(relu, 0); std::vector<size_t> mask_shape = AnfAlgo::GetOutputInferShape(relu, 0);
if (mask_shape.size() != 4) { if (mask_shape.size() != 4) {
MS_LOG(DEBUG) << "relu's infer shape size not equal 4"; MS_LOG(DEBUG) << "relu's infer shape size not equal 4";

View File

@ -29,6 +29,9 @@ bool CheckValueNodeInputOfMul(const AnfNodePtr &node) {
if (!node->isa<ValueNode>()) { if (!node->isa<ValueNode>()) {
return false; return false;
} }
if (AnfAlgo::IsDynamicShape(node)) {
return false;
}
std::vector<size_t> mul_input_shape = AnfAlgo::GetOutputInferShape(node, 0); std::vector<size_t> mul_input_shape = AnfAlgo::GetOutputInferShape(node, 0);
return mul_input_shape.empty() || (mul_input_shape.size() == 1 && mul_input_shape[0] == 1); return mul_input_shape.empty() || (mul_input_shape.size() == 1 && mul_input_shape[0] == 1);
} }

View File

@ -85,6 +85,9 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode
break; break;
} }
} }
if (AnfAlgo::IsDynamicShape(mul->input(lossscale_input_index))) {
return nullptr;
}
auto constant_shape = AnfAlgo::GetOutputInferShape(mul->input(lossscale_input_index), 0); auto constant_shape = AnfAlgo::GetOutputInferShape(mul->input(lossscale_input_index), 0);
if (!(constant_shape.size() == 0 || (constant_shape.size() == 1 && constant_shape[0] == 1))) { if (!(constant_shape.size() == 0 || (constant_shape.size() == 1 && constant_shape[0] == 1))) {
MS_LOG(DEBUG) << "The const input of Mul node must be scalar or shape=(1,), but shape size is " MS_LOG(DEBUG) << "The const input of Mul node must be scalar or shape=(1,), but shape size is "

View File

@ -45,6 +45,10 @@ const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, cons
if (IsUsedByOthers(func_graph, in_reshape)) { if (IsUsedByOthers(func_graph, in_reshape)) {
return nullptr; return nullptr;
} }
if (AnfAlgo::IsDynamicShape(out_reshape) || AnfAlgo::IsDynamicShape(in_reshape)) {
return nullptr;
}
auto output_shape = AnfAlgo::GetOutputDeviceShape(out_reshape, 0); auto output_shape = AnfAlgo::GetOutputDeviceShape(out_reshape, 0);
auto input_shape = AnfAlgo::GetInputDeviceShape(in_reshape, 0); auto input_shape = AnfAlgo::GetInputDeviceShape(in_reshape, 0);
if (kernel::IsSameShape(input_shape, output_shape)) { if (kernel::IsSameShape(input_shape, output_shape)) {

View File

@ -50,6 +50,9 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
MS_EXCEPTION_IF_NULL(transpose_cnode); MS_EXCEPTION_IF_NULL(transpose_cnode);
auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum); auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum);
MS_EXCEPTION_IF_NULL(reshape_cnode); MS_EXCEPTION_IF_NULL(reshape_cnode);
if (AnfAlgo::IsDynamicShape(transpose_cnode) || AnfAlgo::IsDynamicShape(reshape_cnode)) {
return nullptr;
}
std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0); std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0);
std::vector<size_t> transpose_output0_shape = AnfAlgo::GetOutputInferShape(transpose_cnode, 0); std::vector<size_t> transpose_output0_shape = AnfAlgo::GetOutputInferShape(transpose_cnode, 0);
if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_output0_shape)) { if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_output0_shape)) {

View File

@ -50,6 +50,9 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
MS_EXCEPTION_IF_NULL(reshape_cnode); MS_EXCEPTION_IF_NULL(reshape_cnode);
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum); auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum);
MS_EXCEPTION_IF_NULL(transpose_cnode); MS_EXCEPTION_IF_NULL(transpose_cnode);
if (AnfAlgo::IsDynamicShape(transpose_cnode) || AnfAlgo::IsDynamicShape(reshape_cnode)) {
return nullptr;
}
std::vector<size_t> reshape_output0_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0); std::vector<size_t> reshape_output0_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0);
std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0); std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0);
if (!CheckShapeDimInfo(reshape_output0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) { if (!CheckShapeDimInfo(reshape_output0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {

View File

@ -26,6 +26,8 @@
#include "base/base_ref.h" #include "base/base_ref.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "base/core_ops.h" #include "base/core_ops.h"
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "frontend/operator/ops.h"
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
#include "runtime/device/kernel_info.h" #include "runtime/device/kernel_info.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
@ -394,6 +396,7 @@ bool IsNopNode(const AnfNodePtr &node) {
context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) { context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
return false; return false;
} }
static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName, static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,
prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(), prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(),
kFlattenGradOpName}; kFlattenGradOpName};

View File

@ -55,6 +55,10 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
continue; continue;
} }
} }
if (AnfAlgo::IsDynamicShape(cnode)) {
MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope();
continue;
}
ConstInputToAttr(cnode, reg.GetConstInputAttrInfo()); ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
} }
return node; return node;

View File

@ -28,6 +28,7 @@
#include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/kernel_build_info.h" #include "backend/kernel_compiler/kernel_build_info.h"
#include "common/trans.h" #include "common/trans.h"
#include "abstract/param_validator.h"
namespace mindspore { namespace mindspore {
namespace session { namespace session {
@ -42,12 +43,27 @@ namespace {
constexpr size_t kNopNodeInputSize = 2; constexpr size_t kNopNodeInputSize = 2;
constexpr size_t kNopNodeRealInputIndex = 1; constexpr size_t kNopNodeRealInputIndex = 1;
// A shape is dynamic when any dimension is negative (i.e. unknown).
bool IsShapeDynamic(const abstract::ShapePtr &shape) {
  MS_EXCEPTION_IF_NULL(shape);
  for (const auto dim : shape->shape()) {
    if (dim < 0) {
      return true;
    }
  }
  return false;
}
// Convert a (possibly dynamic) shape into size_t dimensions. For a dynamic
// shape the max_shape is used instead of the static shape; every max_shape
// entry must be non-negative, otherwise this raises.
std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
  MS_EXCEPTION_IF_NULL(shape);
  std::vector<size_t> result;
  if (!IsShapeDynamic(shape)) {
    std::transform(shape->shape().begin(), shape->shape().end(), std::back_inserter(result), IntToSize);
    return result;
  }
  const auto &max_shape = shape->max_shape();
  if (std::any_of(max_shape.begin(), max_shape.end(), [](int dim) { return dim < 0; })) {
    MS_LOG(EXCEPTION) << "Invalid Max Shape";
  }
  std::transform(max_shape.begin(), max_shape.end(), std::back_inserter(result), IntToSize);
  return result;
}
enum ShapeType { kMaxShape, kMinShape };
} // namespace } // namespace
AnfNodePtr AnfRuntimeAlgorithm::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) { AnfNodePtr AnfRuntimeAlgorithm::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
@ -1206,19 +1222,6 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, s
return GetCNodeOutputPrecision(kernel_with_index.first); return GetCNodeOutputPrecision(kernel_with_index.first);
} }
bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto has_attr = AnfAlgo::HasNodeAttr(kAttrIsDynamicShape, cnode);
if (!has_attr) {
return false;
}
return AnfAlgo::GetNodeAttr<bool>(node, kAttrIsDynamicShape);
}
bool AnfRuntimeAlgorithm::IsCondControlKernel(const CNodePtr &node) { bool AnfRuntimeAlgorithm::IsCondControlKernel(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (node->inputs().empty()) { if (node->inputs().empty()) {
@ -1252,5 +1255,96 @@ bool AnfRuntimeAlgorithm::IsIndependentNode(const CNodePtr &node) {
} }
return true; return true;
} }
// Read a boolean attribute from a node. Returns false when the node is not a
// CNode or does not carry the attribute at all (missing attr == false).
bool AnfRuntimeAlgorithm::GetBooleanAttr(const AnfNodePtr &node, const std::string &attr) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  return AnfAlgo::HasNodeAttr(attr, cnode) && AnfAlgo::GetNodeAttr<bool>(node, attr);
}
// A node is dynamic-shape when either its inputs or its outputs are flagged so.
bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) {
  if (GetBooleanAttr(node, kAttrInputIsDynamicShape)) {
    return true;
  }
  return GetBooleanAttr(node, kAttrOutputIsDynamicShape);
}
// Translate a size_t shape into a signed dynamic shape, mapping the SIZE_MAX
// placeholder to -1 (unknown dimension).
void AnfRuntimeAlgorithm::GetRealDynamicShape(const std::vector<size_t> &shape,
                                              NotNull<std::vector<int64_t> *> dynamic_shape) {
  for (const auto dim : shape) {
    const int64_t value = (dim == SIZE_MAX) ? -1 : SizeToLong(dim);
    dynamic_shape->push_back(value);
  }
}
// Fetch the index-th shape out of a tuple/list shape and return its max or
// min variant; falls back to the static shape when that variant is empty.
// Raises when the index is out of range or the element is not a plain Shape.
std::vector<int> GetShapeFromSequeueShape(const abstract::SequeueShapePtr &sequeue_shape_ptr, size_t index,
                                          ShapeType type) {
  MS_EXCEPTION_IF_NULL(sequeue_shape_ptr);
  auto shape_list = sequeue_shape_ptr->shape();
  if (index >= shape_list.size()) {
    MS_LOG(EXCEPTION) << "Output Index:" << index << " >= " << shape_list.size();
  }
  auto base_shape = shape_list[index];
  MS_EXCEPTION_IF_NULL(base_shape);
  if (!base_shape->isa<abstract::Shape>()) {
    MS_LOG(EXCEPTION) << "Invalid Shape Type In Shape List";
  }
  auto shape_ptr = base_shape->cast<abstract::ShapePtr>();
  const auto &variant = (type == kMaxShape) ? shape_ptr->max_shape() : shape_ptr->min_shape();
  return variant.empty() ? shape_ptr->shape() : variant;
}
// Max shape of the index-th input, resolved through the producing node.
std::vector<int> AnfRuntimeAlgorithm::GetInputMaxShape(const AnfNodePtr &anf_node, size_t index) {
  auto prev_output = AnfAlgo::GetPrevNodeOutput(anf_node, index);
  return GetOutputMaxShape(prev_output.first, prev_output.second);
}
// Min shape of the index-th input, resolved through the producing node.
std::vector<int> AnfRuntimeAlgorithm::GetInputMinShape(const AnfNodePtr &anf_node, size_t index) {
  auto prev_output = AnfAlgo::GetPrevNodeOutput(anf_node, index);
  return GetOutputMinShape(prev_output.first, prev_output.second);
}
// Max shape of the index-th output of anf_node. For a plain Shape the
// max_shape is returned (or the static shape when max_shape is empty); for a
// tuple shape the element at `index` is consulted. Raises on any other type.
std::vector<int> AnfRuntimeAlgorithm::GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  auto shape = anf_node->Shape();
  MS_EXCEPTION_IF_NULL(shape);
  if (shape->isa<abstract::SequeueShape>()) {
    return GetShapeFromSequeueShape(shape->cast<abstract::SequeueShapePtr>(), index, kMaxShape);
  }
  if (shape->isa<abstract::Shape>()) {
    auto shape_ptr = shape->cast<abstract::ShapePtr>();
    return shape_ptr->max_shape().empty() ? shape_ptr->shape() : shape_ptr->max_shape();
  }
  MS_LOG(EXCEPTION) << "Invalid Shape Type";
}
// Min shape of the index-th output of anf_node. For a plain Shape the
// min_shape is returned (or the static shape when min_shape is empty); for a
// tuple shape the element at `index` is consulted. Raises on any other type.
std::vector<int> AnfRuntimeAlgorithm::GetOutputMinShape(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  auto shape = anf_node->Shape();
  MS_EXCEPTION_IF_NULL(shape);
  if (shape->isa<abstract::SequeueShape>()) {
    return GetShapeFromSequeueShape(shape->cast<abstract::SequeueShapePtr>(), index, kMinShape);
  }
  if (shape->isa<abstract::Shape>()) {
    auto shape_ptr = shape->cast<abstract::ShapePtr>();
    return shape_ptr->min_shape().empty() ? shape_ptr->shape() : shape_ptr->min_shape();
  }
  MS_LOG(EXCEPTION) << "Invalid Shape Type";
}
} // namespace session } // namespace session
} // namespace mindspore } // namespace mindspore

View File

@ -221,6 +221,12 @@ class AnfRuntimeAlgorithm {
static bool IsDynamicShape(const AnfNodePtr &node); static bool IsDynamicShape(const AnfNodePtr &node);
static bool IsCondControlKernel(const CNodePtr &node); static bool IsCondControlKernel(const CNodePtr &node);
static bool IsIndependentNode(const CNodePtr &node); static bool IsIndependentNode(const CNodePtr &node);
static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr);
static void GetRealDynamicShape(const std::vector<size_t> &shape, NotNull<std::vector<int64_t> *> dynamic_shape);
static std::vector<int> GetInputMaxShape(const AnfNodePtr &anf_node, size_t index);
static std::vector<int> GetInputMinShape(const AnfNodePtr &anf_node, size_t index);
static std::vector<int> GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index);
static std::vector<int> GetOutputMinShape(const AnfNodePtr &anf_node, size_t index);
}; };
} // namespace session } // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm; using AnfAlgo = session::AnfRuntimeAlgorithm;

View File

@ -127,6 +127,9 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
MS_LOG(INFO) << "Start"; MS_LOG(INFO) << "Start";
std::vector<KernelGraphPtr> all_graphs; std::vector<KernelGraphPtr> all_graphs;
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs); auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
// Update Graph Dynamic Shape Attr
UpdateGraphDynamicShapeAttr(NOT_NULL(root_graph));
root_graph->UpdateGraphDynamicAttr();
BackendOptimization(all_graphs); BackendOptimization(all_graphs);
// empty graph dont entry to backend // empty graph dont entry to backend
if (root_graph->execution_order().empty()) { if (root_graph->execution_order().empty()) {
@ -136,6 +139,7 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
InitRuntimeResource(); InitRuntimeResource();
return root_graph->graph_id(); return root_graph->graph_id();
} }
// create parameter for multiple branch // create parameter for multiple branch
std::set<KernelGraphPtr> memo; std::set<KernelGraphPtr> memo;
CreateMultiBranchOutput(NOT_NULL(root_graph), NOT_NULL(&memo)); CreateMultiBranchOutput(NOT_NULL(root_graph), NOT_NULL(&memo));

View File

@ -1201,6 +1201,17 @@ void KernelGraph::RemoveNodeFromGraph(const AnfNodePtr &node) {
} }
} }
// Scan the execution order and mark the whole graph dynamic-shape as soon as
// any kernel in it is dynamic; otherwise clear the flag.
void KernelGraph::UpdateGraphDynamicAttr() {
  is_dynamic_shape_ = false;
  for (const auto &cnode : execution_order_) {
    if (!AnfAlgo::IsDynamicShape(cnode)) {
      continue;
    }
    MS_LOG(INFO) << "Update Graph Dynamic Attr";
    is_dynamic_shape_ = true;
    return;
  }
}
std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
KernelGraph::~KernelGraph() { KernelGraph::~KernelGraph() {

View File

@ -37,7 +37,13 @@ namespace session {
using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>; using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>;
class KernelGraph : public FuncGraph { class KernelGraph : public FuncGraph {
public: public:
KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), null_output_(false), current_epoch_(0) { KernelGraph()
: graph_id_(0),
start_label_(nullptr),
end_goto_(nullptr),
null_output_(false),
current_epoch_(0),
is_dynamic_shape_(false) {
inputs_ = std::make_shared<std::vector<AnfNodePtr>>(); inputs_ = std::make_shared<std::vector<AnfNodePtr>>();
execution_order_ = {}; execution_order_ = {};
executable_ = true; executable_ = true;
@ -161,6 +167,7 @@ class KernelGraph : public FuncGraph {
void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) { void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) {
child_graph_result_ = child_graph_result; child_graph_result_ = child_graph_result;
} }
void InsertTupleParameterToMakeTupleMap(const AnfNodePtr &param, const AnfNodePtr &make_tuple) { void InsertTupleParameterToMakeTupleMap(const AnfNodePtr &param, const AnfNodePtr &make_tuple) {
if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) { if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) {
return; return;
@ -176,6 +183,9 @@ class KernelGraph : public FuncGraph {
} }
void RemoveNodeFromGraph(const AnfNodePtr &node); void RemoveNodeFromGraph(const AnfNodePtr &node);
void UpdateGraphDynamicAttr();
bool is_dynamic_shape() const { return is_dynamic_shape_; }
private: private:
// remove value node form graph // remove value node form graph
bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
@ -247,10 +257,10 @@ class KernelGraph : public FuncGraph {
std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_; std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_;
uint32_t current_epoch_; uint32_t current_epoch_;
std::unordered_map<AnfNodePtr, AnfNodePtr> tuple_parameter_to_make_tuple_map_; std::unordered_map<AnfNodePtr, AnfNodePtr> tuple_parameter_to_make_tuple_map_;
std::set<AnfNodePtr> visited_nodes_; std::set<AnfNodePtr> visited_nodes_;
std::map<AnfNodePtr, AnfNodePtr> edge_to_; std::map<AnfNodePtr, AnfNodePtr> edge_to_;
std::stack<AnfNodePtr> loop_nodes_; std::stack<AnfNodePtr> loop_nodes_;
bool is_dynamic_shape_;
}; };
} // namespace session } // namespace session
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;

View File

@ -35,6 +35,7 @@
#include "ir/dtype.h" #include "ir/dtype.h"
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "utils/utils.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/worker.h" #include "ps/worker.h"
#include "ps/common.h" #include "ps/common.h"
@ -1405,6 +1406,97 @@ void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tens
executor_->RunGraphAsync(shared_from_this(), graph_id, inputs, outputs); executor_->RunGraphAsync(shared_from_this(), graph_id, inputs, outputs);
} }
bool IsDynamicShape(const NotNull<abstract::ShapePtr> &shape) {
  // A shape is dynamic when at least one dimension is not a known positive size.
  const auto &dims = shape->shape();
  return std::any_of(dims.begin(), dims.end(), [](int dim) { return dim <= 0; });
}
bool IsNodeOutputDynamicShape(const CNodePtr &anf_node_ptr) {
  // Returns true when any output shape of the node contains an unknown dimension.
  MS_EXCEPTION_IF_NULL(anf_node_ptr);
  auto base_shape = anf_node_ptr->Shape();
  if (base_shape == nullptr) {
    MS_LOG(INFO) << "Invalid base shape ptr, node:" << anf_node_ptr->fullname_with_scope();
    return false;
  }
  if (base_shape->isa<abstract::Shape>()) {
    // Single-output node: inspect its shape directly.
    if (IsDynamicShape(NOT_NULL(base_shape->cast<abstract::ShapePtr>()))) {
      return true;
    }
  } else if (base_shape->isa<abstract::TupleShape>()) {
    // Multi-output node: any dynamic tensor element makes the node output-dynamic.
    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    for (size_t i = 0; i < tuple_shape->size(); ++i) {
      auto b_shp = (*tuple_shape)[i];
      if (!b_shp->isa<abstract::Shape>()) {
        // Non-tensor element (e.g. scalar) carries no dims to check.
        continue;
      }
      if (IsDynamicShape(NOT_NULL(b_shp->cast<abstract::ShapePtr>()))) {
        return true;
      }
    }
  }
  return false;
}
bool IsNodeInputDynamicShape(const CNodePtr &anf_node_ptr) {
  // Returns true when any real input of the node has a shape with unknown dims.
  MS_EXCEPTION_IF_NULL(anf_node_ptr);
  auto input_num = AnfAlgo::GetInputTensorNum(anf_node_ptr);
  for (size_t input_idx = 0; input_idx < input_num; ++input_idx) {
    auto prev_output = AnfAlgo::GetPrevNodeOutput(anf_node_ptr, input_idx);
    auto prev_node = prev_output.first;
    auto output_idx = prev_output.second;
    MS_EXCEPTION_IF_NULL(prev_node);
    auto base_shape = prev_node->Shape();
    if (base_shape == nullptr) {
      MS_LOG(INFO) << "Invalid shape ptr, node:" << prev_node->fullname_with_scope();
      continue;
    }
    if (base_shape->isa<abstract::Shape>()) {
      // Plain tensor output: check it directly.
      if (IsDynamicShape(NOT_NULL(base_shape->cast<abstract::ShapePtr>()))) {
        return true;
      }
      continue;
    }
    if (!base_shape->isa<abstract::TupleShape>()) {
      continue;
    }
    // Tuple output: only the element actually consumed by this input matters.
    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    if (output_idx >= tuple_shape->size()) {
      MS_LOG(INFO) << "Node:" << anf_node_ptr->fullname_with_scope() << "Invalid index:" << output_idx
                   << " and tuple_shape size:" << tuple_shape->size();
      continue;
    }
    auto element_shape = (*tuple_shape)[output_idx];
    if (!element_shape->isa<abstract::Shape>()) {
      continue;
    }
    if (IsDynamicShape(NOT_NULL(element_shape->cast<abstract::ShapePtr>()))) {
      return true;
    }
  }
  return false;
}
void SessionBasic::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph) {
  // Tag every kernel whose inputs and/or outputs carry unknown dims so that
  // later passes can treat them as dynamic-shape kernels.
  for (const auto &node : root_graph->execution_order()) {
    const bool input_dynamic = IsNodeInputDynamicShape(NOT_NULL(node));
    const bool output_dynamic = IsNodeOutputDynamicShape(NOT_NULL(node));
    if (input_dynamic || output_dynamic) {
      AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), node);
      MS_LOG(INFO) << "Set Dynamic Shape Attr to Node:" << node->fullname_with_scope();
    }
    if (output_dynamic) {
      AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), node);
      MS_LOG(INFO) << "Set Output Dynamic Shape Attr to Node:" << node->fullname_with_scope();
    }
    if (input_dynamic) {
      AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), node);
      MS_LOG(INFO) << "Set Input Dynamic Shape Attr to Node:" << node->fullname_with_scope();
    }
  }
}
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
if (!ps::Util::IsRoleOfWorker()) { if (!ps::Util::IsRoleOfWorker()) {

View File

@ -172,6 +172,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph); void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph);
void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter); void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter);
AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list); AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list);
void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph);
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_; std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_; std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;

View File

@ -713,5 +713,16 @@ EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrLi
auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs); auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs);
return eval_result; return eval_result;
} }
AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
  // Look up the C++ infer implementation registered for this primitive and run it.
  MS_EXCEPTION_IF_NULL(prim);
  auto &infer_map = GetPrimitiveToEvalImplMap();
  auto iter = infer_map.find(prim);
  if (iter == infer_map.end()) {
    MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()
                      << " primitive type:" << prim->type_name();
  }
  return iter->second.impl_(nullptr, prim, args_spec_list);
}
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore

View File

@ -302,6 +302,8 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) {
} }
EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc" file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc"
"kernel_info.cc" "kernel_runtime.cc" "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" "kernel_info.cc" "executor/dynamic_kernel.cc" "kernel_runtime.cc" "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc"
) )
if (ENABLE_GPU) if (ENABLE_GPU)

View File

@ -372,7 +372,7 @@ kernel::KernelModPtr AscendDeviceAddress::CompileTransDataAndObtainKernelMod(con
// get size // get size
std::vector<size_t> input_size_list; std::vector<size_t> input_size_list;
std::vector<size_t> output_size_list; std::vector<size_t> output_size_list;
(void)kernel::TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list); (void)kernel::TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list, nullptr);
std::string json_name = kernel_json[op_info_str][kernel_name_str]; std::string json_name = kernel_json[op_info_str][kernel_name_str];
// op build // op build
if (constructed_kernel.find(json_name) == constructed_kernel.end()) { if (constructed_kernel.find(json_name) == constructed_kernel.end()) {
@ -382,15 +382,15 @@ kernel::KernelModPtr AscendDeviceAddress::CompileTransDataAndObtainKernelMod(con
while (!build_manager->IsAllTaskFinish()) { while (!build_manager->IsAllTaskFinish()) {
int task_id = -1; int task_id = -1;
std::string task_result; std::string task_result;
std::string pre_build_result; std::string build_result;
auto ret = build_manager->WaitOne(&task_id, &task_result, &pre_build_result); auto ret = build_manager->WaitOne(&task_id, &task_result, &build_result);
if (!ret) { if (!ret) {
MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id; MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id;
} }
if (task_result != "Success") { if (task_result != "Success") {
MS_EXCEPTION(ArgumentError) << "task compile Failed, task id:" << task_id << ", cause:" << task_result; MS_EXCEPTION(ArgumentError) << "task compile Failed, task id:" << task_id << ", cause:" << task_result;
} }
(void)build_manager->TaskFinishProcess(task_id, false); (void)build_manager->TaskFinishProcess(task_id, build_result, false);
} }
constructed_kernel.insert(json_name); constructed_kernel.insert(json_name);
// search cache // search cache

View File

@ -46,12 +46,22 @@
#ifdef MEM_REUSE_DEBUG #ifdef MEM_REUSE_DEBUG
#include "backend/optimizer/mem_reuse/mem_reuse_checker.h" #include "backend/optimizer/mem_reuse/mem_reuse_checker.h"
#endif #endif
#include "runtime/device/ascend/executor/tiling/op_tiling_calculater.h"
#include "runtime/device/ascend/executor/executor_callback.h"
#include "runtime/device/ascend/executor/hccl_dynamic_kernel.h"
#include "profiler/device/ascend/ascend_profiling.h"
#include "profiler/device/ascend/profiling_context.h"
#include "profiler/device/ascend/rt_callback_manager.h"
using ge::model_runner::ModelRunner; using ge::model_runner::ModelRunner;
using mindspore::device::ascend::ProfilingManager; using mindspore::device::ascend::ProfilingManager;
using mindspore::device::ascend::ProfilingUtils; using mindspore::device::ascend::ProfilingUtils;
using mindspore::device::ascend::tasksink::TaskGenerator; using mindspore::device::ascend::tasksink::TaskGenerator;
using mindspore::kernel::tbe::TbeUtils; using mindspore::kernel::tbe::TbeUtils;
using mindspore::profiler::ascend::AscendProfiler;
using mindspore::profiler::ascend::CallbackManager;
using mindspore::profiler::ascend::GetTid;
using mindspore::profiler::ascend::kCallback;
using std::vector; using std::vector;
constexpr uint32_t kTupleTaskId = 0; constexpr uint32_t kTupleTaskId = 0;
@ -135,6 +145,8 @@ void AscendKernelRuntime::ClearGraphModelMap() {
// tell users which dump kernel name not used // tell users which dump kernel name not used
DumpJsonParser::GetInstance().PrintUnusedKernel(); DumpJsonParser::GetInstance().PrintUnusedKernel();
graph_dynamic_kernel_map_.clear();
for (auto &iter : graph_model_map_) { for (auto &iter : graph_model_map_) {
MS_LOG(INFO) << "Ge UnloadModel " << iter.first; MS_LOG(INFO) << "Ge UnloadModel " << iter.first;
auto ret = ModelRunner::Instance().UnloadModel(iter.first); auto ret = ModelRunner::Instance().UnloadModel(iter.first);
@ -160,6 +172,13 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std
MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found"; MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found";
} }
MS_LOG(DEBUG) << "Clear graph:" << graph_id << " dynamic kernels";
if (auto dynamic_kernel_iter = graph_dynamic_kernel_map_.find(graph_id);
dynamic_kernel_iter != graph_dynamic_kernel_map_.end()) {
MS_LOG(DEBUG) << "Start Clear graph:" << graph_id << " dynamic kernel";
graph_dynamic_kernel_map_.erase(dynamic_kernel_iter);
}
MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource"; MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource";
if (auto model_iter = graph_model_map_.find(graph_id); model_iter != graph_model_map_.end()) { if (auto model_iter = graph_model_map_.find(graph_id); model_iter != graph_model_map_.end()) {
MS_LOG(DEBUG) << "Ge UnloadModel " << graph_id; MS_LOG(DEBUG) << "Ge UnloadModel " << graph_id;
@ -233,6 +252,7 @@ bool AscendKernelRuntime::Init() {
InnerSetContext(); InnerSetContext();
return true; return true;
} }
OpTilingCalculater::GetInstance().Init();
// Start up profiling before rtSetDevice // Start up profiling before rtSetDevice
bool ret = ProfilingManager::GetInstance().StartupProfiling(device_id_); bool ret = ProfilingManager::GetInstance().StartupProfiling(device_id_);
if (!ret) { if (!ret) {
@ -342,6 +362,11 @@ bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {
if (!is_task_sink) { if (!is_task_sink) {
return true; return true;
} }
// Do HcomExecutorInitialize
if (graph->is_dynamic_shape() && !HcclExecutorManager::GetInstance().Initialize()) {
MS_LOG(ERROR) << "Init Hccl Executor Failed";
return false;
}
if (!GenTask(graph)) { if (!GenTask(graph)) {
return false; return false;
} }
@ -351,8 +376,35 @@ bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {
return true; return true;
} }
bool AscendKernelRuntime::GenDynamicKernel(const session::KernelGraph *graph) {
  // Build one DynamicKernel per node in execution order and cache the list
  // under the graph id for RunDynamicKernelAsync.
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "GenDynamicKernel start";
  auto cnode_list = graph->execution_order();
  std::vector<DynamicKernelPtr> dynamic_kernels;
  for (const auto &cnode : cnode_list) {
    MS_EXCEPTION_IF_NULL(cnode);
    MS_LOG(INFO) << "Generate node:" << cnode->fullname_with_scope() << " dynamic kernel";
    auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
    // Fail fast with a clear message instead of dereferencing a null kernel mod.
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto dynamic_kernel = kernel_mod->GenDynamicKernel(cnode, stream_);
    MS_EXCEPTION_IF_NULL(dynamic_kernel);
    dynamic_kernel->Initialize();
    dynamic_kernels.emplace_back(dynamic_kernel);
  }
  // Each graph may be generated only once; a second attempt is an error.
  auto ret = graph_dynamic_kernel_map_.try_emplace(graph->graph_id(), dynamic_kernels);
  if (!ret.second) {
    MS_LOG(ERROR) << "Graph:" << graph->graph_id() << " already generator executor";
    return false;
  }
  MS_LOG(INFO) << "GenDynamicKernel end";
  return true;
}
bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
InnerSetContext(); InnerSetContext();
if (graph->is_dynamic_shape()) {
MS_LOG(INFO) << "Dynamic Shape Graph Generate Dynamic kernel";
return GenDynamicKernel(graph);
}
if (graph == nullptr) { if (graph == nullptr) {
MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!"; MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!";
} }
@ -407,6 +459,11 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
InnerSetContext(); InnerSetContext();
if (graph->is_dynamic_shape()) {
MS_LOG(INFO) << "Dynamic Shape Graph Skip Load Task Step";
return true;
}
if (graph == nullptr) { if (graph == nullptr) {
MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. "; MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. ";
} }
@ -520,9 +577,70 @@ bool AscendKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink, De
return ret; return ret;
} }
bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph *graph) {
  // Execute a dynamic-shape graph kernel-by-kernel on the stream, re-inferring
  // shapes and re-tiling per kernel, instead of the static sink-mode task path.
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "RunExecutorAsync start. GraphId:" << graph->graph_id();
  // Kernels must have been created by GenDynamicKernel for this graph id.
  auto iter = graph_dynamic_kernel_map_.find(graph->graph_id());
  if (iter == graph_dynamic_kernel_map_.end()) {
    MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Not Found! Please generator executor first";
    return false;
  }
  // Profiling Init
  auto &async_profiler = AscendProfiler::GetInstance();
  auto &rt_callback = CallbackManager::GetInstance(stream_);
  rt_callback.Init();
  auto dynamic_kernels = iter->second;
  for (const auto &dynamic_kernel : dynamic_kernels) {
    // Kernels whose shapes depend on other kernels' output *values* must wait
    // for those outputs to be produced before shape inference can run.
    if (dynamic_kernel->have_depends()) {
      MS_LOG(INFO) << "Match Dynamic Kernel, Start SyncStream";
      if (!SyncStream()) {
        MS_LOG(ERROR) << "SyncStream failed";
        return false;
      }
    }
    if (dynamic_kernel->is_dynamic_shape()) {
      // Drain pending post-execute callbacks so upstream shapes are final,
      // then recompute this kernel's shapes and launch arguments.
      ExecutorCallback::GetInstance().Consume();
      dynamic_kernel->InferShape();
      dynamic_kernel->UpdateArgs();
    }
    // Enable profiling trace point start
    rt_callback.RegisterCallback(
      [&]() { RECORD_CALLBACK_EVENT(&async_profiler, dynamic_kernel->GetKernelName().c_str(), "[Callback] start"); });
    dynamic_kernel->Execute();
    // Enable profiling trace point end
    rt_callback.RegisterCallback(
      [&]() { RECORD_CALLBACK_EVENT(&async_profiler, dynamic_kernel->GetKernelName().c_str(), "[Callback] end"); });
    // Post-execute work (e.g. syncing output shapes) is deferred and consumed later.
    ExecutorCallback::GetInstance().RegistCallback([&dynamic_kernel] { dynamic_kernel->PostExecute(); });
  }
  // Final barrier: all launches and registered callbacks complete before return.
  if (!SyncStream()) {
    MS_LOG(ERROR) << "SyncStream failed";
    return false;
  }
  ExecutorCallback::GetInstance().Consume();
  rt_callback.Destroy();
  async_profiler.Dump(std::cout);
  async_profiler.Reset();
  return true;
}
bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
InnerSetContext(); InnerSetContext();
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
if (graph->is_dynamic_shape()) {
MS_LOG(INFO) << "Dynamic Shape Graph Run Task Async";
return RunDynamicKernelAsync(graph);
}
MS_LOG(INFO) << "RunTask start. GraphId:" << graph->graph_id(); MS_LOG(INFO) << "RunTask start. GraphId:" << graph->graph_id();
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
@ -657,7 +775,12 @@ bool AscendKernelRuntime::DestroyHccl() {
MS_LOG(INFO) << "Hccl is not enable, no need to close."; MS_LOG(INFO) << "Hccl is not enable, no need to close.";
return true; return true;
} }
// Dynamic Shape Hccl Finalize
if (!HcclExecutorManager::GetInstance().Finalize()) {
MS_LOG(ERROR) << "Dynamic Shape Hccl Finalize Failed";
}
HcclResult res = hcom_destroy(); HcclResult res = hcom_destroy();
if (res != HCCL_SUCCESS) { if (res != HCCL_SUCCESS) {
MS_LOG(ERROR) << "Hccl destroy failed"; MS_LOG(ERROR) << "Hccl destroy failed";
return false; return false;

View File

@ -40,6 +40,8 @@ class AscendKernelRuntime : public KernelRuntime {
bool Init() override; bool Init() override;
bool LoadData(session::KernelGraph *graph, Debugger *debugger) override; bool LoadData(session::KernelGraph *graph, Debugger *debugger) override;
bool GenTask(const session::KernelGraph *graph); bool GenTask(const session::KernelGraph *graph);
bool GenDynamicKernel(const session::KernelGraph *graph) override;
bool RunDynamicKernelAsync(const session::KernelGraph *graph) override;
bool LoadTask(const session::KernelGraph *graph); bool LoadTask(const session::KernelGraph *graph);
bool RunTask(const session::KernelGraph *graph); bool RunTask(const session::KernelGraph *graph);
bool Load(session::KernelGraph *graph, bool is_task_sink) override; bool Load(session::KernelGraph *graph, bool is_task_sink) override;

View File

@ -34,7 +34,7 @@ const uint32_t kHcomMaxTask = 5;
const uint32_t kCommonMaxTask = 350; const uint32_t kCommonMaxTask = 350;
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) { void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
if (IsTaskSink()) { if (IsTaskSink() && !graph_ptr->is_dynamic_shape()) {
Reset(); Reset();
SetLoopSink(); SetLoopSink();
ReorderIndependentOrders(graph_ptr); ReorderIndependentOrders(graph_ptr);

View File

@ -24,7 +24,7 @@
#include "runtime/mem.h" #include "runtime/mem.h"
#include "runtime/kernel.h" #include "runtime/kernel.h"
#include "runtime/rt_model.h" #include "runtime/rt_model.h"
#include "runtime/device/ascend/dump/ge_dump.h" #include "runtime/device/ascend/ge_types_convert.h"
#include "proto/op_mapping_info.pb.h" #include "proto/op_mapping_info.pb.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "debug/data_dump/dump_json_parser.h" #include "debug/data_dump/dump_json_parser.h"
@ -369,13 +369,13 @@ void DataDumper::DumpKernelOutput(const CNodePtr &kernel, void *args, NotNull<ai
auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel, i); auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel, i);
aicpu::dump::Output output; aicpu::dump::Output output;
output.set_data_type(GetGeDataType(data_type)); output.set_data_type(GeTypesConvert::GetGeDataType(data_type));
output.set_format(GetGeFormat(output_format, output_shape.size())); output.set_format(GeTypesConvert::GetGeFormat(output_format, output_shape.size()));
MS_EXCEPTION_IF_NULL(output.mutable_shape()); MS_EXCEPTION_IF_NULL(output.mutable_shape());
for (auto dim : output_shape) { for (auto dim : output_shape) {
output.mutable_shape()->add_dim(dim); output.mutable_shape()->add_dim(dim);
} }
output.set_original_output_format(GetGeFormat(output_format, output_shape.size())); output.set_original_output_format(GeTypesConvert::GetGeFormat(output_format, output_shape.size()));
output.set_address(static_cast<uint64_t>(reinterpret_cast<uintptr_t>(args)) + offset); output.set_address(static_cast<uint64_t>(reinterpret_cast<uintptr_t>(args)) + offset);
// device address data size // device address data size
auto address = AnfAlgo::GetOutputAddr(kernel, i); auto address = AnfAlgo::GetOutputAddr(kernel, i);
@ -409,8 +409,8 @@ void DataDumper::DumpKernelInput(const CNodePtr &kernel, void *args, NotNull<aic
} }
auto output_shape = AnfAlgo::GetOutputDeviceShape(input_node, input_index); auto output_shape = AnfAlgo::GetOutputDeviceShape(input_node, input_index);
input.set_data_type(GetGeDataType(output_type)); input.set_data_type(GeTypesConvert::GetGeDataType(output_type));
input.set_format(GetGeFormat(output_format, output_shape.size())); input.set_format(GeTypesConvert::GetGeFormat(output_format, output_shape.size()));
MS_EXCEPTION_IF_NULL(input.mutable_shape()); MS_EXCEPTION_IF_NULL(input.mutable_shape());
for (auto dim : output_shape) { for (auto dim : output_shape) {
input.mutable_shape()->add_dim(dim); input.mutable_shape()->add_dim(dim);

View File

@ -0,0 +1,182 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/executor/ai_core_dynamic_kernel.h"
#include <regex>
#include <algorithm>
#include <memory>
#include "framework/common/debug/log.h"
#include "utils/log_adapter.h"
#include "runtime/device/ascend/executor/tiling/op_tiling_calculater.h"
#include "register/op_tiling.h"
#include "utils/convert_utils_base.h"
#include "utils/ms_context.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "common/trans.h"
namespace mindspore {
namespace device {
namespace ascend {
AiCoreDynamicKernel::~AiCoreDynamicKernel() {
  // Release the device-side tiling buffer, the only resource this kernel owns.
  if (tiling_data_ptr_ == nullptr) {
    return;
  }
  if (rtFree(tiling_data_ptr_) != RT_ERROR_NONE) {
    MS_LOG(ERROR) << "rtFree tiling_data_ptr_ failed";
  }
}
void AiCoreDynamicKernel::Execute() {
if (stream_ == nullptr) {
MS_LOG(EXCEPTION) << "stream_ptr should not be nullptr.";
}
MS_LOG(INFO) << "Start Execute node:" << cnode_ptr_->fullname_with_scope();
rtL2Ctrl_t *l2ctrl = nullptr;
auto args_size = static_cast<uint32_t>(UlongToUint(sizeof(void *)) * runtime_args_.size());
if (RT_ERROR_NONE != rtKernelLaunch(stub_func_, block_dim_, runtime_args_.data(), args_size, l2ctrl, stream_)) {
MS_LOG(EXCEPTION) << "Call runtime rtKernelLaunch error.";
}
MS_LOG(INFO) << "End Execute node:" << cnode_ptr_->fullname_with_scope();
}
std::string ReplaceInvalidJsonStr(const std::string &str) {
  // Normalize a python-flavoured compile-info string into parseable JSON:
  // bare sentinel dim constants become quoted strings and python booleans
  // become JSON booleans.
  std::string patched = std::regex_replace(str, std::regex("100000000"), R"("100000000")");
  patched = std::regex_replace(patched, std::regex("100000001"), R"("100000001")");
  patched = std::regex_replace(patched, std::regex("100000002"), R"("100000002")");
  patched = std::regex_replace(patched, std::regex("True"), "true");
  patched = std::regex_replace(patched, std::regex("False"), "false");
  return patched;
}
void AiCoreDynamicKernel::ParseCompileJson() {
  // Parse the TBE compile-info attr into compile_info_json_, which runtime
  // tiling (ComputeTiling) consumes. Only dynamic-shape kernels carry it.
  if (!AnfAlgo::IsDynamicShape(cnode_ptr_)) {
    return;
  }
  if (!AnfAlgo::HasNodeAttr(kAttrCompileInfo, cnode_ptr_)) {
    MS_LOG(EXCEPTION) << "Get compile_info failed";
  }
  auto compile_info_attr = AnfAlgo::GetNodeAttr<std::string>(cnode_ptr_, kAttrCompileInfo);
  // The attr arrives as a python-repr string: swap single quotes for double
  // quotes and normalize python literals before handing it to nlohmann::json.
  std::replace(compile_info_attr.begin(), compile_info_attr.end(), '\'', '\"');
  compile_info_attr = ReplaceInvalidJsonStr(compile_info_attr);
  MS_LOG(INFO) << "Get compile_info:" << compile_info_attr;
  try {
    compile_info_json_ = std::make_shared<nlohmann::json>(nlohmann::json::parse(compile_info_attr));
  } catch (nlohmann::json::parse_error &e) {
    MS_LOG(EXCEPTION) << "parse json failed, error:" << e.what();
  }
  // Inject the fusion type as "_pattern" — presumably the schedule-pattern key
  // expected by the tiling library; verify against op_tiling usage.
  if (AnfAlgo::HasNodeAttr(kAttrFusionType, cnode_ptr_)) {
    auto fusion_type = AnfAlgo::GetNodeAttr<std::string>(cnode_ptr_, kAttrFusionType);
    MS_LOG(INFO) << "Get fusion_type:" << fusion_type;
    (*compile_info_json_)["_pattern"] = fusion_type;
  }
}
void AiCoreDynamicKernel::Initialize() {
  // Run base initialization first, then cache the TBE compile-info json
  // needed later for runtime tiling.
  DynamicKernel::Initialize();
  ParseCompileJson();
}
void AiCoreDynamicKernel::UpdateArgs() {
  // Rebuild the launch argument list for the current input shapes:
  // re-tile, stage tiling data on device, reallocate workspaces, then pack
  // input/output/workspace/tiling pointers into runtime_args_.
  ComputeTiling();
  if (!CopyTilingToDevice()) {
    MS_LOG(EXCEPTION) << "Copy tiling to device failed";
  }
  // Workspace sizes are an output of tiling, so allocation happens after it.
  AllocateWorkspace();
  auto kernel_mod = AnfAlgo::GetKernelMod(cnode_ptr_);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  AddressPtrList kernel_inputs;
  // kernel_workspaces is filled by GenLaunchArgs but not used below; the
  // freshly allocated workspace_addr_ buffers are appended instead.
  AddressPtrList kernel_workspaces;
  AddressPtrList kernel_outputs;
  KernelRuntime::GenLaunchArgs(*kernel_mod, cnode_ptr_, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
  // Argument order: inputs, outputs, workspaces, then (optionally) tiling data.
  runtime_args_.clear();
  (void)std::transform(std::begin(kernel_inputs), std::end(kernel_inputs), std::back_inserter(runtime_args_),
                       [](const AddressPtr &input) -> void * { return input->addr; });
  (void)std::transform(std::begin(kernel_outputs), std::end(kernel_outputs), std::back_inserter(runtime_args_),
                       [](const AddressPtr &output) -> void * { return output->addr; });
  // Update workspace
  if (!workspace_addr_.empty()) {
    (void)std::transform(std::begin(workspace_addr_), std::end(workspace_addr_), std::back_inserter(runtime_args_),
                         [](const DeviceAddressPtr &address_ptr) -> void * { return address_ptr->GetMutablePtr(); });
  }
  // The device tiling buffer rides along as the trailing argument when present.
  if (is_dynamic_shape_ && !tiling_data_.empty() && tiling_data_ptr_ != nullptr) {
    runtime_args_.push_back(tiling_data_ptr_);
  }
}
void AiCoreDynamicKernel::ComputeTiling() {
  // Run the op's registered tiling function against the current shapes and
  // cache the results (block dim, workspace sizes, tiling blob) for launch.
  MS_EXCEPTION_IF_NULL(cnode_ptr_);
  MS_LOG(INFO) << "Start compute tiling of:" << cnode_ptr_->fullname_with_scope();
  optiling::OpRunInfo op_run_info;
  OpTilingCalculater::GetInstance().CalculateTiling(NOT_NULL(cnode_ptr_), NOT_NULL(compile_info_json_),
                                                    depend_tensor_map_, NOT_NULL(&op_run_info));
  block_dim_ = op_run_info.block_dim;
  workspaces_size_ = op_run_info.workspaces;
  tiling_data_ = op_run_info.tiling_data.str();
}
void AiCoreDynamicKernel::AllocateWorkspace() {
  // Allocate one device buffer per workspace size reported by the latest tiling.
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  auto runtime_instance = KernelRuntimeManager::Instance().GetSingleKernelRuntime(kAscendDevice, device_id);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  // Drop previous allocations: sizes may differ between launches.
  workspace_addr_.clear();
  for (auto size : workspaces_size_) {
    auto device_address_ptr = std::make_shared<AscendDeviceAddress>(nullptr, size);
    auto device_ptr = runtime_instance->MallocMem(MemType::kDynamicMem, size, device_address_ptr);
    if (device_ptr == nullptr) {
      MS_LOG(EXCEPTION) << "MallocMem from memory pool failed";
    }
    workspace_addr_.emplace_back(device_address_ptr);
  }
}
bool AiCoreDynamicKernel::CopyTilingToDevice() {
  // Asynchronously stage the computed tiling blob into the device buffer.
  // The buffer was sized at build time (op_para_size_); the blob must fit.
  if (tiling_data_.size() > op_para_size_) {
    MS_LOG(EXCEPTION) << "compute tiling size:" << tiling_data_.size()
                      << " larger than tbe build op_para_size:" << op_para_size_;
  }
  // Nothing to copy when the op produced no tiling data or no buffer exists.
  if (tiling_data_.empty() || tiling_data_ptr_ == nullptr) {
    MS_LOG(INFO) << "tiling size is 0, skip rtMemcpyAsync";
    return true;
  }
  auto ret = rtMemcpyAsync(tiling_data_ptr_, tiling_data_.size(), tiling_data_.c_str(), tiling_data_.size(),
                           RT_MEMCPY_HOST_TO_DEVICE_EX, stream_);
  if (ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "tiling rtMemcpyAsync failed, ret:" << ret;
  }
  return true;
}
// No host-side work is needed after an AI Core launch; intentionally empty.
void AiCoreDynamicKernel::PostExecute() {}
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CORE_DYNAMIC_KERNEL_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CORE_DYNAMIC_KERNEL_H_
#include <vector>
#include <map>
#include <string>
#include <memory>
#include "nlohmann/json.hpp"
#include "ir/tensor.h"
#include "runtime/device/device_address.h"
#include "mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h"
namespace mindspore {
namespace device {
namespace ascend {
// Dynamic-shape executor for a compiled TBE (AI Core) kernel: recomputes
// tiling per launch, stages it on device and launches the kernel stub.
class AiCoreDynamicKernel : public DynamicKernel {
 public:
  // `stub_func` is the registered kernel stub (fixes the misspelled
  // parameter name `stub_fubc`); `tiling_data_ptr` is an already-allocated
  // device buffer of `op_para_size` bytes, owned (freed) by this object.
  AiCoreDynamicKernel(const void *stub_func, uint32_t block_dim, void *tiling_data_ptr, uint32_t op_para_size,
                      void *stream, const CNodePtr &cnode_ptr, const std::vector<void *> &runtime_args)
      : DynamicKernel(stream, cnode_ptr),
        stub_func_(stub_func),
        block_dim_(block_dim),
        tiling_data_ptr_(tiling_data_ptr),
        op_para_size_(op_para_size),
        runtime_args_(runtime_args) {}
  ~AiCoreDynamicKernel() override;

  void Execute() override;
  void UpdateArgs() override;
  void Initialize() override;
  void PostExecute() override;

 protected:
  void AllocateWorkspace();
  void ParseCompileJson();

 private:
  const void *stub_func_;  // kernel stub handle (not owned)
  uint32_t block_dim_;     // number of cores used at launch
  void *tiling_data_ptr_;  // device ptr
  uint32_t op_para_size_;  // size of tiling_data_ptr_
  std::vector<void *> runtime_args_;
  std::string tiling_data_;
  std::vector<int64_t> workspaces_size_;
  std::vector<DeviceAddressPtr> workspace_addr_;
  std::shared_ptr<nlohmann::json> compile_info_json_;

  void ComputeTiling();
  bool CopyTilingToDevice();
};
} // namespace ascend
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CORE_DYNAMIC_KERNEL_H_

View File

@ -0,0 +1,204 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/executor/ai_cpu_dynamic_kernel.h"
#include <vector>
#include <memory>
#include <algorithm>
#include "runtime/mem.h"
#include "runtime/kernel.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/aicpu/aicpu_util.h"
#include "runtime/device/ascend/executor/executor_callback.h"
namespace mindspore {
namespace device {
namespace ascend {
AiCpuDynamicKernel::~AiCpuDynamicKernel() {
  // Release the device-side ext info buffer if it was ever allocated.
  if (ext_info_addr_dev_ != nullptr) {
    if (rtFree(ext_info_addr_dev_) != RT_ERROR_NONE) {
      MS_LOG(ERROR) << "rtFree failed";
    }
  }
}
// Refresh launch arguments before execution: device addresses first, then the
// ext info for dynamic-shape kernels.
void AiCpuDynamicKernel::UpdateArgs() {
  if (!UpdateInputOutputAddr()) {
    MS_LOG(EXCEPTION) << "Update input output failed";
  }
  if (!is_dynamic_shape_) {
    return;
  }
  if (!UpdateExtInfo()) {
    MS_LOG(EXCEPTION) << "Update ExtInfo failed";
  }
}
void AiCpuDynamicKernel::Execute() {
MS_LOG(INFO) << "Execute AiCpuDynamicKerenl Start";
auto ret = rtCpuKernelLaunchWithFlag(
reinterpret_cast<const void *>(so_name_.c_str()), reinterpret_cast<const void *>(kernel_name_.c_str()), 1,
reinterpret_cast<const void *>(args_.data()), args_.size(), nullptr, stream_, RT_KERNEL_DEFAULT);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtCpuKernelLaunchWithFlag Failed";
}
}
// Prepare the kernel for launch: parse the ext info blob (dynamic shape only)
// and copy it to a freshly allocated device buffer referenced by the
// AicpuParamHead at the front of args_.
void AiCpuDynamicKernel::Initialize() {
  // Null-check before the first dereference; the original only checked inside
  // the dynamic-shape branch, after cnode_ptr_ had already been used.
  MS_EXCEPTION_IF_NULL(cnode_ptr_);
  MS_LOG(INFO) << "Initialize node:" << cnode_ptr_->fullname_with_scope();
  DynamicKernel::Initialize();
  input_num_ = AnfAlgo::GetInputTensorNum(cnode_ptr_);
  output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_);

  // Parse aicpu ext info
  if (is_dynamic_shape_) {
    ext_info_handler_ = std::make_shared<AicpuExtInfoHandler>(cnode_ptr_->fullname_with_scope(), input_num_,
                                                              output_num_, DEPEND_COMPUTE);
    // A parse failure previously went unnoticed; fail fast instead.
    if (!ext_info_handler_->Parse(ext_info_data_)) {
      MS_LOG(EXCEPTION) << "Parse ext info failed, node:" << cnode_ptr_->fullname_with_scope();
    }
  }

  if (ext_info_data_.empty()) {
    MS_LOG(INFO) << "No need to copy to device, ext_info_data_ is empty. ";
    return;
  }

  // Allocate ext info addr in device
  auto ret = rtMalloc(&ext_info_addr_dev_, ext_info_data_.size(), RT_MEMORY_HBM);
  if (ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rtMalloc ext_info_addr_dev_ failed";
  }
  ext_info_size_ = ext_info_data_.size();
  ret = rtMemcpy(ext_info_addr_dev_, ext_info_size_, ext_info_data_.data(), ext_info_data_.size(),
                 RT_MEMCPY_HOST_TO_DEVICE);
  if (ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rtMemcpy ext_info_addr_dev_ failed";
  }

  // Patch the AICPU parameter head so the kernel can locate the ext info.
  auto aicpu_param_head = reinterpret_cast<kernel::AicpuParamHead *>(args_.data());
  aicpu_param_head->extInfoLength = ext_info_size_;
  aicpu_param_head->extInfoAddr = reinterpret_cast<uint64_t>(ext_info_addr_dev_);
}
bool AiCpuDynamicKernel::UpdateInputOutputAddr() {
std::vector<uint64_t> io_addrs;
io_addrs.reserve(input_num_ + output_num_);
for (size_t i = 0; i < input_num_; ++i) {
auto input_addr = AnfAlgo::GetPrevNodeOutputAddr(cnode_ptr_, i);
io_addrs.emplace_back(reinterpret_cast<uintptr_t>(input_addr->GetMutablePtr()));
}
for (size_t i = 0; i < output_num_; ++i) {
auto output_addr = AnfAlgo::GetOutputAddr(cnode_ptr_, i);
io_addrs.emplace_back(reinterpret_cast<uintptr_t>(output_addr->GetMutablePtr()));
}
if (args_.empty()) {
MS_LOG(ERROR) << "args_ is empty";
return false;
}
auto io_ptr = args_.data() + sizeof(kernel::AicpuParamHead);
auto ret =
memcpy_s(io_ptr, args_.size() - sizeof(kernel::AicpuParamHead), &io_addrs[0], sizeof(uint64_t) * io_addrs.size());
if (ret != 0) {
MS_LOG(EXCEPTION) << "Memcpy input output addr failed";
}
return true;
}
bool AiCpuDynamicKernel::UpdateExtInfo() {
MS_LOG(INFO) << "UpdateExtInfo of " << cnode_ptr_->fullname_with_scope() << " start";
if (input_num_ == 0 && output_num_ == 0) {
MS_LOG(INFO) << "Node:" << cnode_ptr_->fullname_with_scope() << " no need to update output shape";
return true;
}
for (size_t i = 0; i < input_num_; ++i) {
ext_info_handler_->UpdateInputShapeAndType(i, NOT_NULL(cnode_ptr_));
}
if (unknow_type_ != DEPEND_COMPUTE) {
for (size_t i = 0; i < output_num_; ++i) {
ext_info_handler_->UpdateOutputShapeAndType(i, NOT_NULL(cnode_ptr_));
}
}
auto ret = rtMemcpy(ext_info_addr_dev_, ext_info_size_, ext_info_handler_->GetExtInfo(),
ext_info_handler_->GetExtInfoLen(), RT_MEMCPY_HOST_TO_DEVICE);
if (ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "UpdateExtInfo rtMemcpy failed";
return false;
}
MS_LOG(INFO) << "UpdateExtInfo of " << cnode_ptr_->fullname_with_scope() << " end";
return true;
}
bool AiCpuDynamicKernel::UpdateOutputShapeFromExtInfo() {
if (input_num_ == 0) {
MS_LOG(WARNING) << "input num is 0";
return true;
}
MS_LOG(INFO) << "UpdateOutputShapeFromExtInfo start";
auto ret = rtMemcpy(ext_info_handler_->GetExtInfo(), ext_info_handler_->GetExtInfoLen(), ext_info_addr_dev_,
ext_info_size_, RT_MEMCPY_DEVICE_TO_HOST);
if (ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "rtMemcpy output shape failed";
return false;
}
MS_LOG(INFO) << "rtMemcpy from device to host success";
std::vector<TypeId> type_ids;
std::vector<std::vector<size_t>> shapes;
for (size_t i = 0; i < output_num_; ++i) {
MS_LOG(INFO) << "Get output:" << output_num_ << " Shape";
std::vector<int64_t> shape;
TypeId type_id;
ext_info_handler_->GetOutputShapeAndType(i, NOT_NULL(&shape), NOT_NULL(&type_id));
for (auto x : shape) {
MS_LOG(INFO) << "Update output:" << i << " shape:" << x;
}
type_ids.emplace_back(type_id);
std::vector<size_t> size_t_shape;
std::transform(shape.begin(), shape.end(), std::back_inserter(size_t_shape), LongToSize);
shapes.emplace_back(size_t_shape);
}
AnfAlgo::SetOutputInferTypeAndShape(type_ids, shapes, cnode_ptr_.get());
return true;
}
void AiCpuDynamicKernel::PostExecute() {
MS_LOG(INFO) << "Aicpu " << cnode_ptr_->fullname_with_scope() << " PostExecute";
if (AnfAlgo::IsDynamicShape(cnode_ptr_) && unknow_type_ == DEPEND_COMPUTE) {
MS_LOG(INFO) << "Update aicpu kernel output shape from ext_info";
UpdateOutputShapeFromExtInfo();
}
}
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,76 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CPU_DYNAMIC_KERNEL_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CPU_DYNAMIC_KERNEL_H_
#include <string>
#include <memory>
#include "runtime/device/executor/dynamic_kernel.h"
#include "ir/anf.h"
#include "runtime/device/ascend/executor/aicpu_ext_info_handle.h"
namespace mindspore {
namespace device {
namespace ascend {
// Dynamic-shape kernel wrapper for AICPU operators. Owns the serialized
// launch args and ext info blob; Initialize() copies the ext info to device,
// UpdateArgs() refreshes addresses/shapes before each launch, and
// PostExecute() reads computed output shapes back for DEPEND_COMPUTE ops.
class AiCpuDynamicKernel : public DynamicKernel {
 public:
  // args: serialized AicpuParamHead + io address table.
  // ext_info_data: serialized ext info blob (may be empty).
  // so_name/kernel_name: AICPU library and kernel symbol to launch.
  AiCpuDynamicKernel(void *stream, const CNodePtr &cnode_ptr, const std::string &args, const std::string &ext_info_data,
                     const std::string &so_name, const std::string &kernel_name)
      : DynamicKernel(stream, cnode_ptr),
        args_(args),
        ext_info_data_(ext_info_data),
        so_name_(so_name),
        kernel_name_(kernel_name),
        ext_info_handler_(nullptr),
        ext_info_addr_dev_(nullptr),
        ext_info_size_(0),
        input_num_(0),
        output_num_(0),
        unknow_type_(DEPEND_COMPUTE) {}
  ~AiCpuDynamicKernel() override;

  void UpdateArgs() override;
  void Execute() override;
  void Initialize() override;
  void PostExecute() override;

  // Get Compute Shape from ExtInfo (device-to-host readback after execution).
  bool UpdateOutputShapeFromExtInfo();

 private:
  std::string args_;            // serialized launch arguments
  std::string ext_info_data_;   // host copy of the ext info blob
  std::string so_name_;
  std::string kernel_name_;
  std::shared_ptr<AicpuExtInfoHandler> ext_info_handler_;
  void *ext_info_addr_dev_;     // device buffer for the ext info (freed in dtor)
  size_t ext_info_size_;
  size_t input_num_;
  size_t output_num_;
  UnknowShapeOpType unknow_type_;

  bool UpdateInputOutputAddr();
  bool UpdateExtInfo();
};
} // namespace ascend
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CPU_DYNAMIC_KERNEL_H_

View File

@ -0,0 +1,218 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/executor/aicpu_ext_info_handle.h"
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/aicpu/aicpu_util.h"
namespace mindspore {
namespace device {
namespace ascend {
namespace {
// if dim count is not reach kMaxShapeDims(8), use INT64_MIN to mark dim end.
constexpr int64_t kDimEndFlag = INT64_MIN;
} // namespace
// Parses the serialized ext-info blob into typed sections. The blob is a
// sequence of (AicpuExtInfo header, payload) records; this walks the records
// and captures pointers to the input/output ShapeAndType payloads so later
// shape updates can be written in place.
// Returns false if ext_info is empty; throws on a malformed record.
bool AicpuExtInfoHandler::Parse(const std::string &ext_info) {
  MS_LOG(INFO) << "Parse Node:" << node_name_ << " start";
  if (ext_info.empty()) {
    MS_LOG(ERROR) << "Node:" << node_name_ << " ext_info is empty";
    return false;
  }

  // Keep a private mutable copy of the blob; input/output_shape_and_type_
  // will point into this buffer.
  ext_info_len_ = ext_info.size();
  ext_info_.reset(new (std::nothrow) uint8_t[ext_info_len_]);
  MS_EXCEPTION_IF_NULL(ext_info_);
  // NOTE(review): the memcpy_s result is discarded — presumably the sizes
  // always match since both come from ext_info; confirm.
  (void)memcpy_s(ext_info_.get(), ext_info_len_, ext_info.c_str(), ext_info.size());

  input_shape_and_type_.clear();
  output_shape_and_type_.clear();

  auto ext_info_data = ext_info_.get();
  size_t offset = 0;
  // Walk records while a complete header still fits in the remaining buffer.
  while (offset + sizeof(AicpuExtInfo) <= ext_info_len_) {
    auto aicpu_ext_info = reinterpret_cast<AicpuExtInfo *>(ext_info_data + offset);
    MS_EXCEPTION_IF_NULL(aicpu_ext_info);
    switch (aicpu_ext_info->infoType) {
      case kernel::FWK_ADPT_EXT_SHAPE_TYPE:
        if (!ParseExtShapeType(aicpu_ext_info)) {
          MS_LOG(EXCEPTION) << "Parse ext shape type failed.";
        }
        break;
      case kernel::FWK_ADPT_EXT_INPUT_SHAPE:
        if (!ParseExtInputShape(aicpu_ext_info)) {
          MS_LOG(EXCEPTION) << "Parse ext input shape failed.";
        }
        break;
      case kernel::FWK_ADPT_EXT_OUTPUT_SHAPE:
        if (!ParseExtOutputShape(aicpu_ext_info)) {
          MS_LOG(EXCEPTION) << "Parse ext output shape failed.";
        }
        break;
      default:
        // Unknown record types are skipped, not rejected.
        MS_LOG(INFO) << "Ignore Node:" << node_name_ << " infoType:" << aicpu_ext_info->infoType
                     << " infoLen:" << aicpu_ext_info->infoLen;
        break;
    }
    // Advance past this record's header and payload.
    offset += sizeof(AicpuExtInfo);
    offset += aicpu_ext_info->infoLen;
  }

  // A well-formed blob is consumed exactly; anything else is a format error.
  if (offset != ext_info_len_) {
    MS_LOG(EXCEPTION) << "Node:" << node_name_ << " ext_info format error, parse not reach end, offset=" << offset
                      << ", ext_info_len" << ext_info_len_;
  }
  MS_LOG(INFO) << "Node:" << node_name_ << " parse ext info end.";
  return true;
}
// Validates the shape-type record: the payload must be one int32_t holding
// the expected unknown-shape category.
// NOTE(review): a category mismatch is logged at ERROR but still reported as
// success (no `return false`). The only caller in this commit always
// constructs the handler with DEPEND_COMPUTE, so tightening this could start
// rejecting previously tolerated blobs — confirm intent before changing.
bool AicpuExtInfoHandler::ParseExtShapeType(AicpuExtInfo *aicpu_ext_info) {
  if (aicpu_ext_info->infoLen != sizeof(int32_t)) {
    MS_LOG(ERROR) << "Node:" << node_name_ << " parse ext shape type failed as infoLen must be " << sizeof(int32_t)
                  << " but got:" << aicpu_ext_info->infoLen;
    return false;
  }
  auto type = reinterpret_cast<const int32_t *>(aicpu_ext_info->infoMsg);

  if (*type != unknown_type_) {
    MS_LOG(ERROR) << "Node:" << node_name_ << " parse ext shape type failed as need:" << unknown_type_
                  << " but got:" << *type;
  }
  MS_LOG(INFO) << "Node:" << node_name_ << "parse ext shape type success infoLen=" << aicpu_ext_info->infoLen;
  return true;
}
// Captures one ShapeAndType pointer per input from the input-shape record.
// The payload must hold exactly input_num_ records.
bool AicpuExtInfoHandler::ParseExtInputShape(AicpuExtInfo *aicpu_ext_info) {
  auto need_len = input_num_ * sizeof(AicpuShapeAndType);
  if (aicpu_ext_info->infoLen != need_len) {
    MS_LOG(ERROR) << "Node:" << node_name_
                  << " parse ext input shape failed as aicpu_ext_info->infoLen:" << aicpu_ext_info->infoLen
                  << " and need_len:" << need_len;
    // Fixed: the original fell through and indexed input_num_ records out of
    // a payload that is known to be the wrong size.
    return false;
  }
  auto input = reinterpret_cast<AicpuShapeAndType *>(aicpu_ext_info->infoMsg);

  for (uint32_t index = 0; index < input_num_; ++index) {
    input_shape_and_type_.emplace_back(&input[index]);
  }
  MS_LOG(INFO) << "Node:" << node_name_.c_str() << " parse ext input shape success infoLen=" << aicpu_ext_info->infoLen;
  return true;
}
// Captures one ShapeAndType pointer per output from the output-shape record.
// The payload must hold exactly output_num_ records.
bool AicpuExtInfoHandler::ParseExtOutputShape(AicpuExtInfo *aicpu_ext_info) {
  auto need_len = output_num_ * sizeof(AicpuShapeAndType);
  if (aicpu_ext_info->infoLen != need_len) {
    // Fixed: this failure was logged at INFO; use ERROR like the sibling
    // ParseExt* checks.
    MS_LOG(ERROR) << "Node:" << node_name_
                  << " parse ext output shape failed, aicpu_ext_info->infoLen:" << aicpu_ext_info->infoLen
                  << " need_len:" << need_len;
    return false;
  }
  auto output = reinterpret_cast<AicpuShapeAndType *>(aicpu_ext_info->infoMsg);

  for (uint32_t index = 0; index < output_num_; ++index) {
    output_shape_and_type_.emplace_back(&output[index]);
  }
  MS_LOG(INFO) << "Node:" << node_name_ << " parse ext output shape success infoLen=" << aicpu_ext_info->infoLen;
  return true;
}
bool AicpuExtInfoHandler::UpdateInputShapeAndType(uint32_t input_index, const NotNull<AnfNodePtr> &anf_node) {
if (input_index >= input_num_) {
MS_LOG(ERROR) << "input_index=" << input_index << " >= input_num_:" << input_num_;
return false;
}
auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, input_index);
auto data_type = AnfAlgo::GetInputDeviceDataType(anf_node, input_index);
std::vector<int64_t> tmp_shape;
std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(tmp_shape), SizeToLong);
return UpdateShapeAndType(tmp_shape, data_type, NOT_NULL(input_shape_and_type_[input_index]));
}
bool AicpuExtInfoHandler::UpdateOutputShapeAndType(uint32_t output_index, const NotNull<AnfNodePtr> &anf_node) {
if (output_index >= output_num_) {
MS_LOG(ERROR) << "output_index:" << output_index << " >= output_num_:" << output_num_;
return false;
}
auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index);
auto max_shape = AnfAlgo::GetOutputMaxShape(anf_node, output_index);
if (shape.size() != max_shape.size()) {
MS_LOG(ERROR) << "shape size != max_shape size";
return true;
}
for (size_t i = 0; i < shape.size(); ++i) {
if (i < max_shape.size() && shape[i] == SIZE_MAX) {
MS_LOG(INFO) << "Node:" << node_name_ << " update shape from SIZE_MAX to " << max_shape[i];
shape[i] = max_shape[i];
}
}
std::vector<int64_t> tmp_shape;
std::transform(shape.begin(), shape.end(), std::back_inserter(tmp_shape), SizeToLong);
return UpdateShapeAndType(tmp_shape, AnfAlgo::GetOutputDeviceDataType(anf_node, output_index),
NOT_NULL(output_shape_and_type_[output_index]));
}
bool AicpuExtInfoHandler::GetOutputShapeAndType(uint32_t output_index, NotNull<std::vector<int64_t> *> shape,
NotNull<TypeId *> data_type) {
MS_LOG(INFO) << "Get " << node_name_ << " Output:" << output_index << " Shape And Type";
GetShapeAndType(NOT_NULL(output_shape_and_type_[output_index]), shape, data_type);
return true;
}
bool AicpuExtInfoHandler::UpdateShapeAndType(const std::vector<int64_t> &shape, TypeId data_type,
NotNull<AicpuShapeAndType *> shape_and_type) {
if (shape.empty() || shape.size() > kernel::kMaxShapeDims) {
MS_LOG(ERROR) << "Invalid shape:" << shape.size();
return false;
}
size_t index = 0;
for (; index < shape.size(); ++index) {
shape_and_type->dims[index] = shape[index];
}
if (index < kernel::kMaxShapeDims) {
shape_and_type->dims[index] = kDimEndFlag;
}
// now only support update shape, type is not support
return true;
}
// Decodes a ShapeAndType record: dims up to the kDimEndFlag terminator, then
// the proto type converted to a MindSpore TypeId. Throws on an unknown type.
void AicpuExtInfoHandler::GetShapeAndType(NotNull<const AicpuShapeAndType *> shape_and_type,
                                          NotNull<std::vector<int64_t> *> shape, NotNull<TypeId *> data_type) {
  for (int64_t tmpDim : shape_and_type->dims) {
    if (tmpDim == kDimEndFlag) {
      break;
    }
    shape->emplace_back(tmpDim);
    MS_LOG(INFO) << "Debug tmpDim:" << tmpDim;
  }

  auto ms_type = kernel::AicpuOpUtil::ProtoTypeToMsType(shape_and_type->type);
  if (ms_type == -1) {
    // Fixed typo in the exception message ("Unspport" -> "Unsupported").
    MS_LOG(EXCEPTION) << "Unsupported Proto Type:" << shape_and_type->type;
  }
  MS_LOG(INFO) << "Debug ms_type:" << ms_type;
  *data_type = static_cast<TypeId>(ms_type);
}
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,88 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AICPU_EXT_INFO_HANDLE_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AICPU_EXT_INFO_HANDLE_H_
#include <string>
#include <vector>
#include <utility>
#include <memory>
#include "backend/kernel_compiler/aicpu/aicpu_util.h"
#include "utils/contract.h"
namespace mindspore {
namespace device {
namespace ascend {
// Categories of unknown-shape (dynamic) operators: how an op's output shape
// can be determined at runtime.
enum UnknowShapeOpType {
  DEPEND_IN_SHAPE = 1,     // op out shape get by input shape
  DEPEND_CONST_VALUE = 2,  // op out shape get by const op value
  DEPEND_SHAPE_RANGE = 3,  // op out shape get by range
  DEPEND_COMPUTE = 4       // op out shape get by totally computing
};

// Aliases for the serialized record types shared with the AICPU framework.
using AicpuShapeAndType = kernel::ShapeAndType;
using AicpuExtInfo = kernel::ExtInfo;
// Owns a host-side copy of an AICPU kernel's serialized "ext info" blob and
// exposes typed, in-place access to the per-input / per-output ShapeAndType
// records inside it. Parse() must succeed before the Update*/Get* accessors
// are used; the Update* methods rewrite shapes directly in the owned buffer.
class AicpuExtInfoHandler {
 public:
  // node_name: used in log messages only.
  // input_num / output_num: expected record counts in the blob.
  // unknown_type: expected unknown-shape category recorded in the blob.
  AicpuExtInfoHandler(std::string node_name, uint32_t input_num, uint32_t output_num, UnknowShapeOpType unknown_type)
      : node_name_(std::move(node_name)),
        input_num_(input_num),
        output_num_(output_num),
        unknown_type_(unknown_type),
        ext_info_len_(0) {}

  ~AicpuExtInfoHandler() = default;

  // Raw access to the owned blob (e.g. for rtMemcpy to/from the device).
  uint8_t *GetExtInfo() const { return ext_info_.get(); }
  size_t GetExtInfoLen() const { return ext_info_len_; }

  bool Parse(const std::string &ext_info);

  bool UpdateInputShapeAndType(uint32_t input_index, const NotNull<AnfNodePtr> &anf_node);

  bool UpdateOutputShapeAndType(uint32_t output_index, const NotNull<AnfNodePtr> &anf_node);

  bool GetOutputShapeAndType(uint32_t output_index, NotNull<std::vector<int64_t> *> shape, NotNull<TypeId *> data_type);

 private:
  bool ParseExtShapeType(AicpuExtInfo *aicpu_ext_info);
  bool ParseExtInputShape(AicpuExtInfo *aicpu_ext_info);
  bool ParseExtOutputShape(AicpuExtInfo *aicpu_ext_info);

  static bool UpdateShapeAndType(const std::vector<int64_t> &shape, TypeId data_type,
                                 NotNull<AicpuShapeAndType *> shape_and_type);

  static void GetShapeAndType(NotNull<const AicpuShapeAndType *> shape_and_type, NotNull<std::vector<int64_t> *> shape,
                              NotNull<TypeId *> data_type);

 private:
  const std::string node_name_;
  const uint32_t input_num_;
  const uint32_t output_num_;
  UnknowShapeOpType unknown_type_;

  size_t ext_info_len_;
  // Owned copy of the blob; the two vectors below point into this buffer.
  std::unique_ptr<uint8_t[]> ext_info_;
  std::vector<AicpuShapeAndType *> input_shape_and_type_;
  std::vector<AicpuShapeAndType *> output_shape_and_type_;
};
} // namespace ascend
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AICPU_EXT_INFO_HANDLE_H_

View File

@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/executor/executor_callback.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace device {
namespace ascend {
// Enqueue one callback to be run by a later Consume(); thread-safe.
void ExecutorCallback::RegistCallback(const std::function<void()> &callback) {
  std::lock_guard<std::mutex> guard(lock_);
  callback_queue_.push(callback);
}
// Drain the queue, running callbacks in registration (FIFO) order.
// The whole drain, including callback execution, happens under lock_.
void ExecutorCallback::Consume() {
  std::lock_guard<std::mutex> guard(lock_);
  while (!callback_queue_.empty()) {
    std::function<void()> task = callback_queue_.front();
    callback_queue_.pop();
    // An empty std::function would be a registration bug; fail loudly.
    if (!task) {
      MS_LOG(EXCEPTION) << "callback_func is empty";
    }
    task();
  }
}
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_CALLBACK_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_CALLBACK_H_
#include <queue>
#include <mutex>
#include <functional>
#include "utils/ms_utils.h"
namespace mindspore {
namespace device {
namespace ascend {
// Process-wide FIFO of deferred callbacks. Producers enqueue work via
// RegistCallback(); Consume() runs and clears everything pending. Both
// operations are serialised by an internal mutex.
class ExecutorCallback {
 public:
  static ExecutorCallback &GetInstance() {
    static ExecutorCallback instance;
    return instance;
  }

  // Enqueue one callback; thread-safe.
  void RegistCallback(const std::function<void()> &callback);
  // Run all queued callbacks in order; throws if a queued callback is empty.
  void Consume();

 private:
  ExecutorCallback() = default;
  ~ExecutorCallback() = default;
  DISABLE_COPY_AND_ASSIGN(ExecutorCallback);

  std::queue<std::function<void()>> callback_queue_;
  std::mutex lock_;  // guards callback_queue_
};
} // namespace ascend
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_CALLBACK_H_

View File

@ -0,0 +1,187 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/executor/hccl_dynamic_kernel.h"
#include <dlfcn.h>
#include <vector>
#include "hccl/hcom.h"
#include "common/opskernel/ge_task_info.h"
#include "utils/log_adapter.h"
#include "runtime/device/kernel_runtime.h"
#include "backend/kernel_compiler/hccl/hcom_util.h"
namespace {
// Find so in RPATH or LD_LIBRARY_PATH (/usr/local/Ascend/fwkacllib/lib64/)
constexpr auto kHcomGraphAdaptorPath = "libhcom_graph_adaptor.so";
} // namespace
namespace mindspore {
namespace device {
namespace ascend {
// Recomputes HCCL launch parameters (input/output buffers and element count)
// from the node's current shapes before a dynamic-shape launch.
void HcclDynamicKernel::UpdateArgs() {
  if (!is_dynamic_shape_) {
    // Static shapes: the args captured at task-build time are still valid.
    MS_LOG(INFO) << "Not Dynamic Shape";
    return;
  }
  MS_LOG(INFO) << "Start to UpdateArgs";
  auto kernel_mod = AnfAlgo::GetKernelMod(cnode_ptr_);
  MS_EXCEPTION_IF_NULL(kernel_mod);

  // Update input, output, count
  AddressPtrList kernel_inputs;
  AddressPtrList kernel_workspaces;
  AddressPtrList kernel_outputs;
  KernelRuntime::GenLaunchArgs(*kernel_mod, cnode_ptr_, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
  if (kernel_inputs.empty() || kernel_outputs.empty()) {
    MS_LOG(EXCEPTION) << "Inputs or outputs is empty";
  }
  auto input0 = kernel_inputs.at(0);
  auto output0 = kernel_outputs.at(0);
  MS_EXCEPTION_IF_NULL(input0);
  MS_EXCEPTION_IF_NULL(output0);

  // Update Hccl input and output
  // NOTE(review): only the first input/output buffer is used — presumably the
  // HCCL ops handled here carry a single tensor each; confirm for multi-input
  // collectives.
  input_ptr_ = input0->addr;
  output_ptr_ = output0->addr;

  std::vector<std::vector<size_t>> hccl_kernel_input_shape_list;
  if (!HcomUtil::GetKernelInputShape(cnode_ptr_, &hccl_kernel_input_shape_list)) {
    MS_LOG(EXCEPTION) << "GetKernelInputShape fail!";
  }

  std::vector<HcclDataType> hccl_data_type_list;
  if (!HcomUtil::GetHcomDataType(cnode_ptr_, &hccl_data_type_list)) {
    MS_LOG(EXCEPTION) << "GetHcomDataType fail!";
  }

  // Update Hccl count
  if (!HcomUtil::GetHcomCount(cnode_ptr_, hccl_data_type_list, hccl_kernel_input_shape_list, &count_)) {
    MS_LOG(EXCEPTION) << "GetHcomCount fail!";
  }
  MS_LOG(INFO) << "Update Hccl count:" << count_;
}
void HcclDynamicKernel::StaticShapeExecute() {
MS_EXCEPTION_IF_NULL(cnode_ptr_);
auto kernel_mod = AnfAlgo::GetKernelMod(cnode_ptr_);
MS_EXCEPTION_IF_NULL(kernel_mod);
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode_ptr_, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
}
// Launches the collective. Static-shape graphs go through the kernel mod;
// dynamic-shape graphs enqueue the op via the hcom graph adaptor and block
// until the adaptor's completion callback fires.
void HcclDynamicKernel::Execute() {
  MS_LOG(INFO) << "Start Execute";
  if (!is_dynamic_shape_) {
    MS_LOG(INFO) << "Not Dynamic, call hcom api";
    StaticShapeExecute();
    return;
  }
  // Resolve the enqueue entry point from the already-loaded adaptor library.
  auto handle = HcclExecutorManager::GetInstance().handle();
  auto EnqueueHcomOperation =
    (HcclResult(*)(ge::HcomOpertion, std::function<void(HcclResult status)>))dlsym(handle, "EnqueueHcomOpertion");
  if (EnqueueHcomOperation == nullptr) {
    MS_LOG(ERROR) << "Failed to get EnqueueHcomOperation function";
    if (dlclose(handle) != 0) {
      MS_LOG(WARNING) << "Failed to close hcom handle";
    }
    MS_LOG(EXCEPTION) << "Hccl dynamic kernel execute failed";
    return;  // NOTE(review): unreachable — MS_LOG(EXCEPTION) throws.
  }

  // Build the op descriptor from state refreshed by UpdateArgs().
  ge::HcomOpertion op_info;
  op_info.hcclType = hccl_type_;
  op_info.inputPtr = input_ptr_;
  op_info.outputPtr = output_ptr_;
  op_info.dataType = data_type_;
  op_info.opType = op_type_;
  op_info.root = root_;
  op_info.count = count_;

  // Completion callback: wakes the waiting thread below.
  auto callback = [this](HcclResult status) {
    if (status != HCCL_SUCCESS) {
      MS_LOG(ERROR) << "HcomExcutorInitialize failed, ret:" << status;
    }
    std::lock_guard<std::mutex> lock(this->hccl_mutex_);
    this->cond_.notify_all();
    MS_LOG(INFO) << "hccl callback success.";
  };

  auto hccl_ret = EnqueueHcomOperation(op_info, callback);
  if (hccl_ret != HCCL_SUCCESS) {
    MS_LOG(EXCEPTION) << "Call EnqueueHcomOperation failed";
  }

  // NOTE(review): wait() has no predicate — a spurious wakeup, or a callback
  // that notifies before this thread starts waiting, is not handled. Consider
  // waiting on a flag set inside the callback; fixing this requires shared
  // state with the lambda above.
  std::unique_lock<std::mutex> ulock(hccl_mutex_);
  cond_.wait(ulock);
  MS_LOG(INFO) << "Execute success";
}
// No post-execution processing is performed for HCCL dynamic kernels.
void HcclDynamicKernel::PostExecute() {}
// Loads the hcom graph adaptor library and runs its executor initialization.
// Idempotent: only a fully successful call marks the manager initialized.
bool HcclExecutorManager::Initialize() {
  if (initialized_) {
    return true;
  }
  MS_LOG(INFO) << "Start Initialize Hccl DynamicKernel";
  handle_ = dlopen(kHcomGraphAdaptorPath, RTLD_NOW | RTLD_GLOBAL);
  if (handle_ == nullptr) {
    MS_LOG(ERROR) << "dlopen failed, path:" << kHcomGraphAdaptorPath;
    return false;
  }

  auto HcomExecutorInitialize = (HcclResult(*)())dlsym(handle_, "HcomExcutorInitialize");
  if (HcomExecutorInitialize == nullptr) {
    MS_LOG(ERROR) << "dlsym HcomExecutorInitialize failed";
    return false;
  }

  HcclResult hccl_ret = HcomExecutorInitialize();
  if (hccl_ret == HCCL_E_PTR) {
    MS_LOG(WARNING) << "Hccl comm is null, hcom executor initialize is not required";
  } else if (hccl_ret == HCCL_SUCCESS) {
    MS_LOG(INFO) << "Hcom DynamicKernel Initialize success";
  } else {
    MS_LOG(ERROR) << "Hcom DynamicKernel Initialize failed";
    return false;
  }
  // Fixed: initialized_ was set before any work, so a failed initialization
  // made every later call return true without ever retrying.
  initialized_ = true;
  return true;
}
// Tears down the hcom executor and unloads the adaptor library.
// Safe to call when Initialize() never succeeded.
bool HcclExecutorManager::Finalize() {
  // Fixed: dlsym/dlclose were previously called on a possibly-null handle.
  if (handle_ == nullptr) {
    return true;
  }
  auto HcomExecutorFinalize = (HcclResult(*)())dlsym(handle_, "HcomExcutorFinalize");
  if (HcomExecutorFinalize == nullptr) {
    MS_LOG(ERROR) << "Faile to dlsym HcomExecutorFinalize";
    return false;
  }
  HcclResult hccl_ret = HcomExecutorFinalize();
  if (hccl_ret != HCCL_SUCCESS) {
    MS_LOG(ERROR) << "Hcom DynamicKernel Finalize failed";
    return false;
  }

  if (dlclose(handle_) != 0) {
    MS_LOG(ERROR) << "Failed to close hcom handle";
    return false;
  }
  // Fixed: the success path previously logged "Finalize failed".
  MS_LOG(INFO) << "Hccl DynamicKernel Finalize success";
  return true;
}
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,82 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HCCL_DYNAMIC_KERNEL_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HCCL_DYNAMIC_KERNEL_H_
#include <condition_variable>
#include <string>
#include "runtime/device/executor/dynamic_kernel.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace device {
namespace ascend {
// Dynamic-shape wrapper for an HCCL collective kernel. For dynamic-shape
// graphs it refreshes buffers and element count per launch and enqueues the
// op through the hcom graph adaptor, blocking on a condition variable until
// the completion callback fires; otherwise it launches via the kernel mod.
class HcclDynamicKernel : public DynamicKernel {
 public:
  // hccl_type: collective op name passed to HCCL.
  // input_ptr/output_ptr: device buffers (not owned).
  // count: element count; data_type/op_type/root: HCCL parameters as int32.
  HcclDynamicKernel(const std::string &hccl_type, void *input_ptr, void *output_ptr, uint64_t count, int32_t data_type,
                    int32_t op_type, int32_t root, void *stream, const CNodePtr &cnode_ptr)
      : DynamicKernel(stream, cnode_ptr),
        hccl_type_(hccl_type),
        input_ptr_(input_ptr),
        output_ptr_(output_ptr),
        count_(count),
        data_type_(data_type),
        op_type_(op_type),
        root_(root) {}
  ~HcclDynamicKernel() override = default;
  void UpdateArgs() override;
  void Execute() override;
  void PostExecute() override;

 private:
  std::string hccl_type_;
  void *input_ptr_;   // not owned
  void *output_ptr_;  // not owned
  uint64_t count_{0};
  int32_t data_type_{0};
  int32_t op_type_{0};
  int32_t root_{0};
  std::mutex hccl_mutex_;          // pairs with cond_ for completion waiting
  std::condition_variable cond_;

  // Launch path used when the graph is not dynamic-shape.
  void StaticShapeExecute();
};
// Singleton that loads libhcom_graph_adaptor.so and drives the hcom
// executor's HcomExcutorInitialize/HcomExcutorFinalize entry points.
// handle() exposes the dlopen handle for resolving further symbols.
class HcclExecutorManager {
 public:
  static HcclExecutorManager &GetInstance() {
    static HcclExecutorManager instance;
    return instance;
  }
  bool Initialize();
  bool Finalize();
  void *handle() { return handle_; }

 private:
  HcclExecutorManager() = default;
  ~HcclExecutorManager() = default;
  DISABLE_COPY_AND_ASSIGN(HcclExecutorManager);

  void *handle_{nullptr};      // dlopen handle for the adaptor library
  bool initialized_{false};
};
} // namespace ascend
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HCCL_DYNAMIC_KERNEL_H_

View File

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HOST_DYNAMIC_KERNEL_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HOST_DYNAMIC_KERNEL_H_
#include "runtime/device/executor/dynamic_kernel.h"
namespace mindspore {
namespace device {
namespace ascend {
// Base class for dynamic kernels executed on the host: no argument update or
// post-processing is performed; subclasses supply Execute().
class HostDynamicKernel : public DynamicKernel {
 public:
  HostDynamicKernel(void *stream, const CNodePtr &cnode_ptr) : DynamicKernel(stream, cnode_ptr) {}
  ~HostDynamicKernel() override = default;
  void UpdateArgs() override {}
  void Execute() override = 0;
  void PostExecute() override {}
};
} // namespace ascend
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HOST_DYNAMIC_KERNEL_H_

View File

@ -0,0 +1,32 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.h"
#include "runtime/mem.h"
namespace mindspore {
namespace device {
namespace ascend {
// Launch an asynchronous device-to-device copy of count_ bytes on this
// kernel's stream; throws if the runtime rejects the request.
void MemcpyRtsDynamicKernel::Execute() {
  const auto rt_ret = rtMemcpyAsync(dst_, dest_max_, src_, count_, RT_MEMCPY_DEVICE_TO_DEVICE, stream_);
  if (rt_ret == RT_ERROR_NONE) {
    return;
  }
  MS_LOG(EXCEPTION) << "MemCpyAsync op rtMemcpyAsync failed!";
}
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,45 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_MEMCPY_RTS_DYNAMIC_KERNEL_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_MEMCPY_RTS_DYNAMIC_KERNEL_H_
#include "runtime/device/executor/dynamic_kernel.h"
namespace mindspore {
namespace device {
namespace ascend {
// RTS kernel wrapping an asynchronous device-to-device memcpy.  The copy
// endpoints and sizes are fixed at construction, so nothing needs updating
// per launch.
class MemcpyRtsDynamicKernel : public DynamicKernel {
 public:
  // dst/src are device addresses; dest_max is the destination capacity and
  // count the number of bytes to copy -- presumably count <= dest_max,
  // enforced by the runtime rather than here (TODO(review): confirm).
  MemcpyRtsDynamicKernel(void *stream, const CNodePtr &cnode_ptr, void *dst, uint32_t dest_max, void *src,
                         uint32_t count)
      : DynamicKernel(stream, cnode_ptr), dst_(dst), dest_max_(dest_max), src_(src), count_(count) {}
  ~MemcpyRtsDynamicKernel() override = default;

  void UpdateArgs() override {}
  // Issues rtMemcpyAsync on stream_; defined in the .cc file.
  void Execute() override;
  void PostExecute() override {}

 private:
  void *dst_;
  uint32_t dest_max_;
  void *src_;
  uint32_t count_;
};
} // namespace ascend
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_MEMCPY_RTS_DYNAMIC_KERNEL_H_

View File

@ -0,0 +1,32 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.h"
#include "runtime/base.h"
namespace mindspore {
namespace device {
namespace ascend {
// Emit a profiler trace record on this kernel's stream via the runtime
// service; throws if the runtime call fails.
void ProfilingRtsDynamicKernel::Execute() {
  const auto status = rtProfilerTrace(log_id_, notify_, flags_, stream_);
  if (status == RT_ERROR_NONE) {
    return;
  }
  MS_LOG(EXCEPTION) << "Call rtProfilerTrace failed";
}
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_PROFILING_RTS_DYNAMIC_KERNEL_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_PROFILING_RTS_DYNAMIC_KERNEL_H_
#include "runtime/device/executor/dynamic_kernel.h"
namespace mindspore {
namespace device {
namespace ascend {
// RTS kernel that writes a profiler trace marker for this node on the stream.
class ProfilingRtsDynamicKernel : public DynamicKernel {
 public:
  ProfilingRtsDynamicKernel(void *stream, const CNodePtr &cnode_ptr, uint64_t log_id, bool notify, uint32_t flags)
      : DynamicKernel(stream, cnode_ptr), log_id_(log_id), notify_(notify), flags_(flags) {}
  ~ProfilingRtsDynamicKernel() override = default;

  void UpdateArgs() override {}
  // Calls rtProfilerTrace(log_id_, notify_, flags_, stream_); see the .cc file.
  void Execute() override;
  void PostExecute() override {}

 private:
  uint64_t log_id_;  // trace record identifier forwarded to the runtime
  bool notify_;      // passed straight through to rtProfilerTrace
  uint32_t flags_;   // passed straight through to rtProfilerTrace
};
} // namespace ascend
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_PROFILING_RTS_DYNAMIC_KERNEL_H_

View File

@ -0,0 +1,188 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/executor/tiling/op_tiling_calculater.h"
#include <dlfcn.h>
#include <map>
#include <vector>
#include <memory>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/ascend/ge_types_convert.h"
#include "utils/utils.h"
#include "external/graph/tensor.h"
namespace mindspore {
namespace device {
namespace ascend {
// Build a placeholder ge::Tensor carrying only the GE data type converted
// from the given MindSpore TypeId (no shape, no data buffer).
ge::Tensor MakeTempGeTensor(TypeId type_id) {
  ge::TensorDesc desc;
  desc.SetDataType(GeTypesConvert::TransTypeIdToGeDataType(type_id));
  ge::Tensor tensor;
  tensor.SetTensorDesc(desc);
  return tensor;
}
// Append one TeOpTensorArg per input of |cnode| to |tensor_arg_list|, each
// describing the producing node's device shape, dtype string and tiling
// format for the optiling library.
void FeedTeOpTensorInputArg(const NotNull<CNodePtr> &cnode,
                            NotNull<std::vector<optiling::TeOpTensorArg> *> tensor_arg_list) {
  MS_LOG(INFO) << "FeedTeOpTensorInputArg start, node:" << cnode->fullname_with_scope();
  auto input_size = AnfAlgo::GetInputTensorNum(cnode.get());
  // NOTE(review): the original comment claimed dynamic-shape depend inputs are
  // skipped, but every input is fed below -- confirm whether skipping was
  // intended.
  for (size_t i = 0; i < input_size; ++i) {
    // Resolve the real producer (node, output index) behind input i.
    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode.get(), i);
    auto input_node = input_node_with_index.first;
    auto input_index = input_node_with_index.second;
    auto output_shape = AnfAlgo::GetOutputDeviceShape(input_node, input_index);
    auto output_format = AnfAlgo::GetOutputFormat(input_node, input_index);
    auto output_dtype = AnfAlgo::GetOutputDeviceDataType(input_node, input_index);
    // type_name_map translates a TypeId into the dtype string optiling expects.
    auto iter = type_name_map.find(output_dtype);
    if (iter == type_name_map.end()) {
      MS_LOG(EXCEPTION) << "Cannot found typeId:" << output_dtype;
    }
    auto ge_output_dtype = iter->second;
    optiling::TeOpTensorArg tensor_arg;
    optiling::TeOpTensor tensor;
    tensor_arg.arg_type = optiling::TA_SINGLE;
    tensor.dtype = ge_output_dtype;
    tensor.shape.insert(tensor.shape.end(), output_shape.begin(), output_shape.end());
    tensor.format = GeTypesConvert::GetGeTilingFormat(GeTypesConvert::GetGeFormat(output_format, output_shape.size()));
    MS_LOG(INFO) << "Tiling Format:" << tensor.format;
    tensor_arg.tensor.emplace_back(tensor);
    tensor_arg_list->emplace_back(tensor_arg);
  }
}
void FeedTeOpTensorOutputArg(const NotNull<CNodePtr> &cnode,
NotNull<std::vector<optiling::TeOpTensorArg> *> tensor_arg_list) {
MS_LOG(INFO) << "FeedTeOpTensorOutputArg start, node:" << cnode->fullname_with_scope();
auto output_size = AnfAlgo::GetOutputTensorNum(cnode.get());
for (size_t i = 0; i < output_size; ++i) {
auto output_shape = AnfAlgo::GetOutputDeviceShape(cnode.get(), i);
auto output_format = AnfAlgo::GetOutputFormat(cnode.get(), i);
auto data_type = AnfAlgo::GetOutputDeviceDataType(cnode.get(), i);
auto iter = type_name_map.find(data_type);
if (iter == type_name_map.end()) {
MS_LOG(EXCEPTION) << "Cannot found typeId:" << data_type;
}
optiling::TeOpTensorArg tensor_arg;
optiling::TeOpTensor tensor;
tensor_arg.arg_type = optiling::TA_SINGLE;
tensor.dtype = iter->second;
tensor.shape.insert(tensor.shape.end(), output_shape.begin(), output_shape.end());
tensor.format = GeTypesConvert::GetGeTilingFormat(GeTypesConvert::GetGeFormat(output_format, output_shape.size()));
MS_LOG(INFO) << "Tiling Format:" << tensor.format;
tensor_arg.tensor.emplace_back(tensor);
tensor_arg_list->emplace_back(tensor_arg);
}
}
// Collect the value-depend input tensors of |cnode| (attr
// "dynamic_shape_depends") into |const_inputs|, keyed by the input's name
// taken from the node's "input_names" attr.  Host data pointers come straight
// from the tensors in |depend_tensor_map|; a placeholder ge::Tensor carries
// only the dtype.
void FeedTeOpConstTensor(const NotNull<CNodePtr> &cnode, const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map,
                         NotNull<std::map<std::string, optiling::TeConstTensorData> *> const_inputs) {
  MS_LOG(INFO) << "FeedTeOpConstTensor start, node:" << cnode->fullname_with_scope();
  if (!AnfAlgo::HasNodeAttr(kDynamicShapeDepends, cnode.get())) {
    // No input values are needed for tiling; shapes alone suffice.
    MS_LOG(INFO) << "No input depend found, " << cnode->fullname_with_scope();
    return;
  }
  auto depends_list = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode.get(), kDynamicShapeDepends);
  for (auto index : depends_list) {
    auto iter = depend_tensor_map.find(IntToSize(index));
    if (iter == depend_tensor_map.end()) {
      MS_LOG(EXCEPTION) << "Index not found in depend_tensor_map";
    }
    auto const_tensor = iter->second;
    // The "input_names" lookup below is loop-invariant but re-read every
    // iteration; harmless, since depends lists are short.
    auto have_input_names_attr = AnfAlgo::HasNodeAttr("input_names", cnode);
    if (!have_input_names_attr) {
      MS_LOG(EXCEPTION) << "cnode:" << cnode->fullname_with_scope() << " no input_names attr";
    }
    auto input_names_attr = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode.get(), "input_names");
    if (IntToSize(index) >= input_names_attr.size()) {
      MS_LOG(EXCEPTION) << "input index" << index << " >= input_name_attr.size:" << input_names_attr.size();
    }
    auto input_name = input_names_attr[index];
    MS_LOG(INFO) << "input_name is " << input_name;
    auto type_id = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode.get(), index);
    // try_emplace: if the same input name appears twice, the first entry wins.
    const_inputs->try_emplace(
      input_name, optiling::TeConstTensorData{static_cast<const uint8_t *>(const_tensor->data_c()),
                                              IntToSize(const_tensor->DataSize()), MakeTempGeTensor(type_id)});
  }
  MS_LOG(INFO) << "FeedTeOpConstTensor end";
}
// Populate tiling_func_map_ with every tiling function the optiling library
// has registered.  Must be called once before CalculateTiling().
void OpTilingCalculater::Init() {
  MS_LOG(INFO) << "Start init OpTilingCalculater";
  tiling_func_map_ = optiling::OpTilingInterf::RegisteredOpInterf();
  MS_LOG(INFO) << "tiling_func_map_ size:" << tiling_func_map_.size();
  for (const auto &iter : tiling_func_map_) {
    // Fixed log-message typo: "Regist" -> "Registered".
    MS_LOG(INFO) << "Registered tiling func:" << iter.first;
  }
}
// Map a front-end op type to the name its tiling function is registered
// under; op types without an alias are returned unchanged.
std::string GetRealOpType(const std::string &op_type) {
  static const std::map<std::string, std::string> kOpTypeMap = {
    {"SparseApplyFtrl", "SparseApplyFtrlD"},
  };
  auto alias = kOpTypeMap.find(op_type);
  return alias == kOpTypeMap.end() ? op_type : alias->second;
}
// Run the registered tiling function for |cnode| and fill |op_run_info|.
// Inputs, outputs and value-depend const inputs are gathered into a
// TeOpParas first; if the op type has no dedicated tiling function the
// generic "AutoTiling" function is used instead.
void OpTilingCalculater::CalculateTiling(const NotNull<CNodePtr> &cnode,
                                         const NotNull<std::shared_ptr<nlohmann::json>> &compile_info_json,
                                         const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map,
                                         NotNull<optiling::OpRunInfo *> op_run_info) {
  optiling::TeOpParas op_param;
  std::string op_type = AnfAlgo::GetCNodeName(cnode.get());
  MS_LOG(INFO) << "[DynamicShape] calculate tiling, op_type:" << op_type;
  FeedTeOpTensorInputArg(cnode, NOT_NULL(&op_param.inputs));
  FeedTeOpTensorOutputArg(cnode, NOT_NULL(&op_param.outputs));
  FeedTeOpConstTensor(cnode, depend_tensor_map, NOT_NULL(&op_param.const_inputs));

  op_type = GetRealOpType(op_type);
  auto iter = tiling_func_map_.find(op_type);
  if (iter == tiling_func_map_.end()) {
    // No dedicated tiling function: fall back to the generic one.
    iter = tiling_func_map_.find("AutoTiling");
    if (iter == tiling_func_map_.end()) {
      MS_LOG(EXCEPTION) << "AutoTiling Func Not Found";
    }
  }
  // iter is guaranteed valid here: both failed lookups above end in an
  // exception.  The previous `iter != end()` re-check and its else branch
  // were unreachable dead code and have been removed.
  MS_LOG(INFO) << "Get tiling func:" << iter->first;
  bool ret = (iter->second)(op_type, op_param, *compile_info_json.get(), *op_run_info);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Calculate tiling failed";
  }
  MS_LOG(INFO) << "CalculateTiling success";
}
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,55 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TILING_OP_TILING_CALCULATE_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TILING_OP_TILING_CALCULATE_H_
#include <map>
#include <memory>
#include <string>
#include "utils/ms_utils.h"
#include "utils/contract.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "register/op_tiling.h"
namespace mindspore {
namespace device {
namespace ascend {
// Singleton that resolves and invokes op tiling functions registered by the
// optiling library for dynamic-shape kernels.
class OpTilingCalculater {
 public:
  static OpTilingCalculater &GetInstance() {
    static OpTilingCalculater instance;
    return instance;
  }

  // Loads the registered tiling-function table; call once before
  // CalculateTiling().
  void Init();
  // Computes tiling for |cnode| into |op_run_info| from the op's compile info
  // and the value-depend input tensors.
  void CalculateTiling(const NotNull<CNodePtr> &cnode,
                       const NotNull<std::shared_ptr<nlohmann::json>> &compile_info_json,
                       const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map,
                       NotNull<optiling::OpRunInfo *> op_run_info);

 private:
  OpTilingCalculater() = default;
  ~OpTilingCalculater() = default;
  DISABLE_COPY_AND_ASSIGN(OpTilingCalculater);

  // Op type -> tiling function, filled by Init().
  std::map<std::string, optiling::OpTilingFunc> tiling_func_map_;
};
} // namespace ascend
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TILING_OP_TILING_CALCULATE_H_

View File

@ -0,0 +1,137 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/ascend/ge_types_convert.h"
namespace mindspore {
namespace device {
namespace ascend {
// Convert a MindSpore TypeId to the GE protobuf DataType used in dumped
// graphs.  Throws if the TypeId has no GE counterpart.
ge::proto::DataType GeTypesConvert::GetGeDataType(TypeId type_id) {
  static const std::map<TypeId, ge::proto::DataType> kTypeIdToProtoType = {
    {TypeId::kTypeUnknown, ge::proto::DT_UNDEFINED},
    {TypeId::kNumberTypeFloat32, ge::proto::DT_FLOAT},
    {TypeId::kNumberTypeFloat16, ge::proto::DT_FLOAT16},
    {TypeId::kNumberTypeInt8, ge::proto::DT_INT8},
    {TypeId::kNumberTypeUInt8, ge::proto::DT_UINT8},
    {TypeId::kNumberTypeInt16, ge::proto::DT_INT16},
    {TypeId::kNumberTypeUInt16, ge::proto::DT_UINT16},
    {TypeId::kNumberTypeInt32, ge::proto::DT_INT32},
    {TypeId::kNumberTypeInt64, ge::proto::DT_INT64},
    {TypeId::kNumberTypeUInt32, ge::proto::DT_UINT32},
    {TypeId::kNumberTypeUInt64, ge::proto::DT_UINT64},
    {TypeId::kNumberTypeBool, ge::proto::DT_BOOL},
    {TypeId::kNumberTypeFloat64, ge::proto::DT_DOUBLE},
  };
  MS_LOG(INFO) << "Vm origin type_id:" << type_id;
  auto found = kTypeIdToProtoType.find(type_id);
  if (found == kTypeIdToProtoType.end()) {
    MS_LOG(EXCEPTION) << "Invalid data type:" << type_id;
  }
  return found->second;
}
// Convert a MindSpore TypeId to the runtime ge::DataType used by GE tensors.
// Throws if the TypeId has no GE counterpart.
ge::DataType GeTypesConvert::TransTypeIdToGeDataType(TypeId type_id) {
  static const std::map<TypeId, ge::DataType> data_type_map = {
    {TypeId::kNumberTypeFloat, ge::DataType::DT_FLOAT},     {TypeId::kNumberTypeFloat32, ge::DataType::DT_FLOAT},
    {TypeId::kNumberTypeFloat16, ge::DataType::DT_FLOAT16}, {TypeId::kNumberTypeInt8, ge::DataType::DT_INT8},
    {TypeId::kNumberTypeInt16, ge::DataType::DT_INT16},     {TypeId::kNumberTypeUInt16, ge::DataType::DT_UINT16},
    {TypeId::kNumberTypeUInt8, ge::DataType::DT_UINT8},     {TypeId::kNumberTypeInt32, ge::DataType::DT_INT32},
    {TypeId::kNumberTypeInt, ge::DataType::DT_INT32},       {TypeId::kNumberTypeInt64, ge::DataType::DT_INT64},
    {TypeId::kNumberTypeUInt32, ge::DataType::DT_UINT32},   {TypeId::kNumberTypeUInt, ge::DataType::DT_UINT32},
    {TypeId::kNumberTypeUInt64, ge::DataType::DT_UINT64},   {TypeId::kNumberTypeBool, ge::DataType::DT_BOOL},
    // BUG FIX: this entry previously repeated kNumberTypeInt64 (a duplicate
    // std::map key, silently dropped at init), leaving kNumberTypeFloat64
    // unmapped so double tensors hit the exception below.
    {TypeId::kNumberTypeFloat64, ge::DataType::DT_DOUBLE},  {TypeId::kTypeUnknown, ge::DataType::DT_UNDEFINED}};
  auto iter = data_type_map.find(type_id);
  if (iter == data_type_map.end()) {
    MS_LOG(EXCEPTION) << "Invalid data type:" << type_id;
  }
  return iter->second;
}
// Map a MindSpore format string to the GE format enum.  The default format
// is rank-dependent: 4-D tensors map to NCHW, everything else to ND.
// Throws on unknown format strings.
GeFormat GeTypesConvert::GetGeFormat(const std::string &format, size_t shape_size) {
  static const std::map<std::string, GeFormat> kFormatTable = {
    // default format: nchw, fractal_nz?
    {kOpFormat_DEFAULT, kFormat_NCHW},
    {kOpFormat_NC1KHKWHWC0, kFormat_NC1KHKWHWC0},
    {kOpFormat_ND, kFormat_ND},
    {kOpFormat_NCHW, kFormat_NCHW},
    {kOpFormat_NHWC, kFormat_NHWC},
    {kOpFormat_HWCN, kFormat_HWCN},
    {kOpFormat_NC1HWC0, kFormat_NC1HWC0},
    {kOpFormat_FRAC_Z, kFormat_FRACTAL_Z},
    {kOpFormat_FRAC_NZ, kFormat_FRACTAL_NZ},
    {kOpFormat_C1HWNCoC0, kFormat_C1HWNCoC0},
    {kOpFormat_NC1HWC0_C04, kFormat_NC1HWC0_C04},
    {kOpFormat_FRACTAL_Z_C04, kFormat_FRACTAL_Z_C04},
    {kOpFormat_NDHWC, kFormat_NDHWC},
  };
  MS_LOG(INFO) << "GetGeFormat format:" << format << " shape_size:" << shape_size;
  if (format == kOpFormat_DEFAULT) {
    return shape_size == 4 ? kFormat_NCHW : kFormat_ND;
  }
  auto entry = kFormatTable.find(format);
  if (entry == kFormatTable.end()) {
    MS_LOG(EXCEPTION) << "Invalid format:" << format;
  }
  return entry->second;
}
// Convert a GeFormat enum value to the format string expected by the
// optiling library.  Throws if the enum value is not in the table.
std::string GeTypesConvert::GetGeTilingFormat(GeFormat ge_format) {
  static const std::map<GeFormat, std::string> kFormatToStringMap = {
    {kFormat_NCHW, "NCHW"},
    {kFormat_NHWC, "NHWC"},
    {kFormat_ND, "ND"},
    {kFormat_NC1HWC0, "NC1HWC0"},
    {kFormat_FRACTAL_Z, "FRACTAL_Z"},
    {kFormat_NC1C0HWPAD, "NC1C0HWPAD"},
    {kFormat_NHWC1C0, "NHWC1C0"},
    {kFormat_FSR_NCHW, "FSR_NCHW"},
    {kFormat_FRACTAL_DECONV, "FRACTAL_DECONV"},
    {kFormat_C1HWNC0, "C1HWNC0"},
    {kFormat_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"},
    {kFormat_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"},
    {kFormat_NC1HWC0_C04, "NC1HWC0_C04"},
    {kFormat_FRACTAL_Z_C04, "FRACTAL_Z_C04"},
    {kFormat_CHWN, "CHWN"},
    // NOTE(review): string intentionally differs from the enum name here --
    // presumably the name the tiling library expects; confirm before "fixing".
    {kFormat_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"},
    {kFormat_NC1KHKWHWC0, "NC1KHKWHWC0"},
    {kFormat_BN_WEIGHT, "BN_WEIGHT"},
    {kFormat_FILTER_HWCK, "FILTER_HWCK"},
    {kFormat_HWCN, "HWCN"},
    {kFormat_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"},
    {kFormat_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"},
    {kFormat_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"},
    {kFormat_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"},
    {kFormat_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"},
    {kFormat_MD, "MD"},
    {kFormat_NDHWC, "NDHWC"},
    {kFormat_NCDHW, "NCDHW"},
    {kFormat_DHWCN, "DHWCN"},
    {kFormat_DHWNC, "DHWNC"},
    {kFormat_NDC1HWC0, "NDC1HWC0"},
    {kFormat_FRACTAL_Z_3D, "FRACTAL_Z_3D"},
    {kFormat_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"},
    {kFormat_C1HWNCoC0, "C1HWNCoC0"},
    {kFormat_FRACTAL_NZ, "FRACTAL_NZ"},
    {kFormat_CN, "CN"},
    {kFormat_NC, "NC"},
    {kFormat_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"},
    {kFormat_FRACTAL_Z_G, "FRACTAL_Z_G"},
    {kFormat_RESERVED, "FORMAT_RESERVED"},
    {kFormat_ALL, "ALL"}};
  auto iter = kFormatToStringMap.find(ge_format);
  if (iter == kFormatToStringMap.end()) {
    MS_LOG(EXCEPTION) << "Invalid ge_format:" << ge_format;
  }
  return iter->second;
}
} // namespace ascend
} // namespace device
} // namespace mindspore

View File

@ -22,28 +22,11 @@
#include "proto/ge_dtype.pb.h" #include "proto/ge_dtype.pb.h"
#include "ir/dtype/type_id.h" #include "ir/dtype/type_id.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "external/graph/types.h"
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
static ge::proto::DataType GetGeDataType(TypeId type_id) {
static const std::map<TypeId, ge::proto::DataType> data_type_map = {
{TypeId::kTypeUnknown, ge::proto::DT_UNDEFINED}, {TypeId::kNumberTypeFloat32, ge::proto::DT_FLOAT},
{TypeId::kNumberTypeFloat16, ge::proto::DT_FLOAT16}, {TypeId::kNumberTypeInt8, ge::proto::DT_INT8},
{TypeId::kNumberTypeUInt8, ge::proto::DT_UINT8}, {TypeId::kNumberTypeInt16, ge::proto::DT_INT16},
{TypeId::kNumberTypeUInt16, ge::proto::DT_UINT16}, {TypeId::kNumberTypeInt32, ge::proto::DT_INT32},
{TypeId::kNumberTypeInt64, ge::proto::DT_INT64}, {TypeId::kNumberTypeUInt32, ge::proto::DT_UINT32},
{TypeId::kNumberTypeUInt64, ge::proto::DT_UINT64}, {TypeId::kNumberTypeBool, ge::proto::DT_BOOL},
{TypeId::kNumberTypeFloat64, ge::proto::DT_DOUBLE},
};
MS_LOG(INFO) << "Vm origin type_id:" << type_id;
auto iter = data_type_map.find(type_id);
if (iter == data_type_map.end()) {
MS_LOG(EXCEPTION) << "Invalid data type:" << type_id;
}
return iter->second;
}
enum GeFormat { enum GeFormat {
kFormat_NCHW = 0, // NCHW kFormat_NCHW = 0, // NCHW
kFormat_NHWC, // NHWC kFormat_NHWC, // NHWC
@ -83,37 +66,21 @@ enum GeFormat {
kFormat_NC, kFormat_NC,
kFormat_DHWNC, kFormat_DHWNC,
kFormat_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format kFormat_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format
kFormat_FRACTAL_ZN_LSTM,
kFormat_FRACTAL_Z_G,
kFormat_RESERVED, kFormat_RESERVED,
kFormat_ALL kFormat_ALL
}; };
static GeFormat GetGeFormat(const std::string &format, size_t shape_size) { class GeTypesConvert {
static const std::map<std::string, GeFormat> format_map = { public:
// default format: nchw, fractal_nz? GeTypesConvert() = default;
{kOpFormat_DEFAULT, kFormat_NCHW}, ~GeTypesConvert() = default;
{kOpFormat_NC1KHKWHWC0, kFormat_NC1KHKWHWC0}, static ge::proto::DataType GetGeDataType(TypeId type_id);
{kOpFormat_ND, kFormat_ND}, static GeFormat GetGeFormat(const std::string &format, size_t shape_size);
{kOpFormat_NCHW, kFormat_NCHW}, static std::string GetGeTilingFormat(GeFormat ge_format);
{kOpFormat_NHWC, kFormat_NHWC}, static ge::DataType TransTypeIdToGeDataType(TypeId type_id);
{kOpFormat_HWCN, kFormat_HWCN}, };
{kOpFormat_NC1HWC0, kFormat_NC1HWC0},
{kOpFormat_FRAC_Z, kFormat_FRACTAL_Z},
{kOpFormat_FRAC_NZ, kFormat_FRACTAL_NZ},
{kOpFormat_C1HWNCoC0, kFormat_C1HWNCoC0},
{kOpFormat_NC1HWC0_C04, kFormat_NC1HWC0_C04},
{kOpFormat_FRACTAL_Z_C04, kFormat_FRACTAL_Z_C04},
{kOpFormat_NDHWC, kFormat_NDHWC},
};
MS_LOG(INFO) << "GetGeFormat format:" << format << " shape_size:" << shape_size;
if (format == kOpFormat_DEFAULT) {
return shape_size == 4 ? kFormat_NCHW : kFormat_ND;
}
auto iter = format_map.find(format);
if (iter == format_map.end()) {
MS_LOG(EXCEPTION) << "Invalid format:" << format;
}
return iter->second;
}
} // namespace ascend } // namespace ascend
} // namespace device } // namespace device
} // namespace mindspore } // namespace mindspore

View File

@ -27,6 +27,7 @@
#include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h" #include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h"
#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h" #include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h"
#include "backend/kernel_compiler/aicpu/aicpu_kernel_build.h" #include "backend/kernel_compiler/aicpu/aicpu_kernel_build.h"
#include "backend/kernel_compiler/host/host_kernel_build.h"
#include "backend/kernel_compiler/hccl/hccl_kernel_build.h" #include "backend/kernel_compiler/hccl/hccl_kernel_build.h"
#include "backend/kernel_compiler/rts/rt_kernel_build.h" #include "backend/kernel_compiler/rts/rt_kernel_build.h"
#include "backend/kernel_compiler/tbe/tbe_utils.h" #include "backend/kernel_compiler/tbe/tbe_utils.h"
@ -47,6 +48,10 @@ static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) {
kernel_mod_ptr = kernel::AicpuOpBuild(anf_node); kernel_mod_ptr = kernel::AicpuOpBuild(anf_node);
break; break;
} }
case KernelType::HOST_KERNEL: {
kernel_mod_ptr = kernel::HostOpBuild(anf_node);
break;
}
case KernelType::RT_KERNEL: { case KernelType::RT_KERNEL: {
kernel_mod_ptr = kernel::RtOpBuild(anf_node); kernel_mod_ptr = kernel::RtOpBuild(anf_node);
break; break;

View File

@ -22,6 +22,10 @@
#include <utility> #include <utility>
#include <algorithm> #include <algorithm>
#include <map> #include <map>
#include <unordered_map>
#include <unordered_set>
#include "utils/ms_utils.h"
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "debug/anf_ir_dump.h" #include "debug/anf_ir_dump.h"
#include "frontend/operator/ops.h" #include "frontend/operator/ops.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
@ -493,7 +497,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
} }
// we set special device info of a input tensor. // we set special device info of a input tensor.
bool is_ref = false; bool is_ref = false;
auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE); auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel_node);
if (op_info != nullptr) { if (op_info != nullptr) {
is_ref = op_info->is_ref(); is_ref = op_info->is_ref();
} }

View File

@ -44,6 +44,8 @@ class CPUKernelRuntime : public KernelRuntime {
VectorRef *outputs); VectorRef *outputs);
void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; }
bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; }
protected: protected:
bool SyncStream() override { return true; }; bool SyncStream() override { return true; };

View File

@ -0,0 +1,128 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/executor/dynamic_kernel.h"
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
#include "common/trans.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"
namespace mindspore {
namespace device {
// Cache per-node dynamic-shape metadata and wrap each value-depend input in
// a host tensor (with the producer's device address attached) so that
// InferShape() can later sync its data device->host.
void DynamicKernel::Initialize() {
  MS_LOG(INFO) << "Init Start";
  is_dynamic_shape_ = AnfAlgo::IsDynamicShape(cnode_ptr_);
  if (!is_dynamic_shape_) {
    // Static-shape node: none of the dynamic machinery is needed.
    MS_LOG(INFO) << "cnode is not dynamic shape:" << cnode_ptr_->fullname_with_scope();
    return;
  }
  is_input_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode_ptr_, kAttrInputIsDynamicShape);
  is_output_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode_ptr_, kAttrOutputIsDynamicShape);
  auto have_depends = AnfAlgo::HasNodeAttr(kDynamicShapeDepends, cnode_ptr_);
  if (!have_depends) {
    MS_LOG(WARNING) << "No dynamic_shape_depends found";
    return;
  }
  MS_LOG(INFO) << "Have depends";
  auto depends_list = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode_ptr_, kDynamicShapeDepends);
  // Save depend input tensor. Sync data in InferShape.
  for (auto depend : depends_list) {
    auto pre_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode_ptr_, depend);
    auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode_ptr_, depend);
    // Host-side tensor shell; data_sync() in InferShape pulls device memory
    // through the attached device address.
    std::vector<int> shapes = trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second);
    auto host_type = AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second);
    auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
    out_tensor->set_device_address(output_addr);
    auto ret = depend_tensor_map_.try_emplace(depend, out_tensor);
    if (!ret.second) {
      // A duplicate depend index in the attr would silently shadow an entry;
      // fail loudly instead.
      MS_LOG(EXCEPTION) << "Insert map failed";
    }
  }
  MS_LOG(INFO) << "Init End";
}
// Return true iff |anf_node| is a CNode whose primitive is TupleGetItem.
bool IsTupleGetItem(const AnfNodePtr &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (!anf_node->isa<CNode>()) {
    return false;
  }
  auto as_cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(as_cnode);
  return IsPrimitive(as_cnode->input(0), prim::kPrimTupleGetItem);
}
// Re-run C++ shape inference for this node from the current abstracts of its
// real inputs; value-depend inputs are first synced device->host so their
// concrete values are visible to the infer functions.
void DynamicKernel::InferShape() {
  // Static inputs + dynamic outputs with no value-depends: re-inference has
  // nothing new to work from, so skip it.
  // NOTE(review): intent inferred from the flag names -- confirm.
  if (!is_input_dynamic_shape_ && is_output_dynamic_shape_ && !have_depends()) {
    return;
  }
  MS_EXCEPTION_IF_NULL(cnode_ptr_);
  MS_LOG(INFO) << "InferShape start, node:" << cnode_ptr_->fullname_with_scope();
  auto inputs = cnode_ptr_->inputs();
  if (inputs.empty()) {
    MS_LOG(EXCEPTION) << "Invalid inputs";
  }
  AbstractBasePtrList args_spec_list;
  // inputs[0] is the primitive; real operands start at inputs[1].
  auto primitive = GetValueNode<PrimitivePtr>(inputs[0]);
  auto input_size = AnfAlgo::GetInputTensorNum(cnode_ptr_);
  for (size_t i = 0; i < input_size; ++i) {
    auto input_with_index = AnfAlgo::GetPrevNodeOutput(cnode_ptr_, i);
    auto real_input = input_with_index.first;
    MS_EXCEPTION_IF_NULL(real_input);
    auto ret = depend_tensor_map_.find(i);
    if (ret != depend_tensor_map_.end()) {
      auto tensor_ptr = ret->second;
      MS_EXCEPTION_IF_NULL(tensor_ptr);
      // sync data from device to host
      tensor_ptr->data_sync();
      // Attach the concrete value so value-dependent infer functions see it.
      real_input->abstract()->set_value(tensor_ptr);
    }
    auto cnode_input = cnode_ptr_->input(i + 1);
    MS_EXCEPTION_IF_NULL(cnode_input);
    if (IsTupleGetItem(cnode_input)) {
      // Input arrives through tuple_get_item: select the element's shape out
      // of the producer's TupleShape.
      auto base_shape = real_input->Shape();
      if (!base_shape->isa<abstract::TupleShape>()) {
        MS_LOG(EXCEPTION) << "Node:" << cnode_ptr_->fullname_with_scope()
                          << " input is a tuple_get_item but real input node shape is not a TupleShape";
      }
      auto tuple_ptr = base_shape->cast<abstract::TupleShapePtr>();
      MS_EXCEPTION_IF_NULL(tuple_ptr);
      auto tuple_get_item_index = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
      auto real_shape = tuple_ptr->shape().at(tuple_get_item_index);
      auto abstract_tensor = cnode_input->abstract()->cast<abstract::AbstractTensorPtr>();
      MS_EXCEPTION_IF_NULL(abstract_tensor);
      args_spec_list.emplace_back(std::make_shared<abstract::AbstractTensor>(abstract_tensor->element(), real_shape));
    } else if (cnode_input->isa<CNode>() && AnfAlgo::GetCNodeName(cnode_input) == prim::kPrimReshape->name()) {
      // Reshape's own abstract already reflects its runtime target shape.
      args_spec_list.emplace_back(cnode_input->abstract());
    } else {
      args_spec_list.emplace_back(real_input->abstract());
    }
  }
  auto eval_result = abstract::CppInferShape(primitive, args_spec_list);
  cnode_ptr_->set_abstract(eval_result);
}
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_H_
#include <memory>
#include <string>
#include <map>
#include "ir/anf.h"
#include "ir/tensor.h"
namespace mindspore {
namespace device {
// Node-attribute key marking which input indices a dynamic-shape kernel's
// shape inference depends on at runtime (value-depend inputs).
constexpr auto kDynamicShapeDepends = "dynamic_shape_depends";

/// \brief Base class for kernels whose tensor shapes are resolved at launch time.
///
/// Wraps one CNode and drives the per-launch cycle of a dynamic-shape
/// operator: InferShape() -> UpdateArgs() -> Execute() -> PostExecute().
/// Concrete device backends implement the pure-virtual steps.
class DynamicKernel {
 public:
  /// \param stream Device stream the kernel launches on (borrowed, not owned).
  /// \param cnode_ptr Graph node this kernel executes.
  DynamicKernel(void *stream, const CNodePtr &cnode_ptr) : stream_(stream), cnode_ptr_(cnode_ptr) {}
  virtual ~DynamicKernel() = default;

  /// Re-runs shape inference for the wrapped node before a launch.
  virtual void InferShape();
  /// Refreshes launch arguments after shapes have been (re-)inferred.
  virtual void UpdateArgs() = 0;
  /// Launches the kernel on stream_.
  virtual void Execute() = 0;
  /// Post-launch work (e.g. retrieving actual output shapes).
  virtual void PostExecute() = 0;

  bool is_dynamic_shape() const { return is_dynamic_shape_; }
  bool is_input_dynamic_shape() const { return is_input_dynamic_shape_; }
  bool is_output_dynamic_shape() const { return is_output_dynamic_shape_; }
  // True when shape inference needs runtime values of some inputs.
  bool have_depends() const { return !depend_tensor_map_.empty(); }

  /// Populates the dynamic-shape flags and depend-tensor map; call once
  /// before the kernel is first executed.
  virtual void Initialize();
  /// Fully-scoped node name, for logging and diagnostics.
  // const-correct: accessor does not mutate the kernel.
  std::string GetKernelName() const { return cnode_ptr_->fullname_with_scope(); }

 protected:
  void *stream_;              // Device stream handle; lifetime owned by the caller.
  const CNodePtr cnode_ptr_;  // Node this kernel executes.
  // Flags cached by Initialize(); default to "static shape" until then.
  bool is_dynamic_shape_{false};
  bool is_input_dynamic_shape_{false};
  bool is_output_dynamic_shape_{false};
  // input index -> host tensor holding that input's runtime value, for the
  // inputs shape inference depends on (see kDynamicShapeDepends).
  std::map<uint32_t, tensor::TensorPtr> depend_tensor_map_;
};
using DynamicKernelPtr = std::shared_ptr<DynamicKernel>;
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_H_

View File

@ -43,6 +43,8 @@ class GPUKernelRuntime : public KernelRuntime {
                     const std::vector<CNodePtr> &execution_order) override;
  void AssignMemory(session::KernelGraph *graph) override;
  bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override;
  bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; }
  bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; }
 protected:
  DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
Some files were not shown because too many files have changed in this diff Show More